Skip to content

Commit 8d16513

Browse files
authored
[mlir][[spirv] Add support for math.log2 and math.log10 to GLSL/OpenCL SPIRV Backends (llvm#104608)
As log2 and log10 are not available in spirv, realize them as a decomposition using spirv.CL.log/spirv.GL.Log.
1 parent 08201cb commit 8d16513

File tree

3 files changed

+123
-28
lines changed

3 files changed

+123
-28
lines changed

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,65 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
291291
}
292292
};
293293

294+
/// Converts math.log2 and math.log10 to SPIR-V ops.
295+
///
296+
/// SPIR-V does not have direct operations for log2 and log10. Explicitly
297+
/// lower to these operations using:
298+
/// log2(x) = log(x) * 1/log(2)
299+
/// log10(x) = log(x) * 1/log(10)
300+
301+
template <typename MathLogOp, typename SpirvLogOp>
302+
struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
303+
using OpConversionPattern<MathLogOp>::OpConversionPattern;
304+
using typename OpConversionPattern<MathLogOp>::OpAdaptor;
305+
306+
static constexpr double log2Reciprocal =
307+
1.442695040888963407359924681001892137426645954152985934135449407;
308+
static constexpr double log10Reciprocal =
309+
0.4342944819032518276511289189166050822943970058036665661144537832;
310+
311+
LogicalResult
312+
matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313+
ConversionPatternRewriter &rewriter) const override {
314+
assert(adaptor.getOperands().size() == 1);
315+
if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
316+
failed(res))
317+
return res;
318+
319+
Location loc = operation.getLoc();
320+
Type type = this->getTypeConverter()->convertType(operation.getType());
321+
if (!type)
322+
return rewriter.notifyMatchFailure(operation, "type conversion failed");
323+
324+
auto getConstantValue = [&](double value) {
325+
if (auto floatType = dyn_cast<FloatType>(type)) {
326+
return rewriter.create<spirv::ConstantOp>(
327+
loc, type, rewriter.getFloatAttr(floatType, value));
328+
}
329+
if (auto vectorType = dyn_cast<VectorType>(type)) {
330+
Type elemType = vectorType.getElementType();
331+
332+
if (isa<FloatType>(elemType)) {
333+
return rewriter.create<spirv::ConstantOp>(
334+
loc, type,
335+
DenseFPElementsAttr::get(
336+
vectorType, FloatAttr::get(elemType, value).getValue()));
337+
}
338+
}
339+
340+
llvm_unreachable("unimplemented types for log2/log10");
341+
};
342+
343+
Value constantValue = getConstantValue(
344+
std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345+
: log10Reciprocal);
346+
Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand());
347+
rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
348+
constantValue);
349+
return success();
350+
}
351+
};
352+
294353
/// Converts math.powf to SPIRV-Ops.
295354
struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
296355
using OpConversionPattern::OpConversionPattern;
@@ -411,6 +470,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
411470
// GLSL patterns
412471
patterns
413472
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
473+
Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
474+
Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
414475
ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
415476
CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
416477
CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
@@ -430,6 +491,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
430491

431492
// OpenCL patterns
432493
patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
494+
Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
495+
Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
433496
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
434497
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
435498
CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,

mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,30 @@ func.func @float32_unary_scalar(%arg0: f32) {
2222
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
2323
// CHECK: spirv.GL.Log %[[ADDONE]]
2424
%5 = math.log1p %arg0 : f32
25+
// CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
26+
// CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
27+
// CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
28+
%6 = math.log2 %arg0 : f32
29+
// CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
30+
// CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
31+
// CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
32+
%7 = math.log10 %arg0 : f32
2533
// CHECK: spirv.GL.RoundEven %{{.*}}: f32
26-
%6 = math.roundeven %arg0 : f32
34+
%8 = math.roundeven %arg0 : f32
2735
// CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
28-
%7 = math.rsqrt %arg0 : f32
36+
%9 = math.rsqrt %arg0 : f32
2937
// CHECK: spirv.GL.Sqrt %{{.*}}: f32
30-
%8 = math.sqrt %arg0 : f32
38+
%10 = math.sqrt %arg0 : f32
3139
// CHECK: spirv.GL.Tanh %{{.*}}: f32
32-
%9 = math.tanh %arg0 : f32
40+
%11 = math.tanh %arg0 : f32
3341
// CHECK: spirv.GL.Sin %{{.*}}: f32
34-
%10 = math.sin %arg0 : f32
42+
%12 = math.sin %arg0 : f32
3543
// CHECK: spirv.GL.FAbs %{{.*}}: f32
36-
%11 = math.absf %arg0 : f32
44+
%13 = math.absf %arg0 : f32
3745
// CHECK: spirv.GL.Ceil %{{.*}}: f32
38-
%12 = math.ceil %arg0 : f32
46+
%14 = math.ceil %arg0 : f32
3947
// CHECK: spirv.GL.Floor %{{.*}}: f32
40-
%13 = math.floor %arg0 : f32
48+
%15 = math.floor %arg0 : f32
4149
return
4250
}
4351

@@ -59,16 +67,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
5967
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
6068
// CHECK: spirv.GL.Log %[[ADDONE]]
6169
%5 = math.log1p %arg0 : vector<3xf32>
70+
// CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
71+
// CHECK: %[[LOG0:.+]] = spirv.GL.Log {{.+}}
72+
// CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
73+
%6 = math.log2 %arg0 : vector<3xf32>
74+
// CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
75+
// CHECK: %[[LOG1:.+]] = spirv.GL.Log {{.+}}
76+
// CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
77+
%7 = math.log10 %arg0 : vector<3xf32>
6278
// CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
63-
%6 = math.roundeven %arg0 : vector<3xf32>
79+
%8 = math.roundeven %arg0 : vector<3xf32>
6480
// CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
65-
%7 = math.rsqrt %arg0 : vector<3xf32>
81+
%9 = math.rsqrt %arg0 : vector<3xf32>
6682
// CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
67-
%8 = math.sqrt %arg0 : vector<3xf32>
83+
%10 = math.sqrt %arg0 : vector<3xf32>
6884
// CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
69-
%9 = math.tanh %arg0 : vector<3xf32>
85+
%11 = math.tanh %arg0 : vector<3xf32>
7086
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
71-
%10 = math.sin %arg0 : vector<3xf32>
87+
%12 = math.sin %arg0 : vector<3xf32>
7288
return
7389
}
7490

mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,34 @@ func.func @float32_unary_scalar(%arg0: f32) {
2020
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
2121
// CHECK: spirv.CL.log %[[ADDONE]]
2222
%5 = math.log1p %arg0 : f32
23+
// CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
24+
// CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
25+
// CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
26+
%6 = math.log2 %arg0 : f32
27+
// CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
28+
// CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
29+
// CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
30+
%7 = math.log10 %arg0 : f32
2331
// CHECK: spirv.CL.rint %{{.*}}: f32
24-
%6 = math.roundeven %arg0 : f32
32+
%8 = math.roundeven %arg0 : f32
2533
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
26-
%7 = math.rsqrt %arg0 : f32
34+
%9 = math.rsqrt %arg0 : f32
2735
// CHECK: spirv.CL.sqrt %{{.*}}: f32
28-
%8 = math.sqrt %arg0 : f32
36+
%10 = math.sqrt %arg0 : f32
2937
// CHECK: spirv.CL.tanh %{{.*}}: f32
30-
%9 = math.tanh %arg0 : f32
38+
%11 = math.tanh %arg0 : f32
3139
// CHECK: spirv.CL.sin %{{.*}}: f32
32-
%10 = math.sin %arg0 : f32
40+
%12 = math.sin %arg0 : f32
3341
// CHECK: spirv.CL.fabs %{{.*}}: f32
34-
%11 = math.absf %arg0 : f32
42+
%13 = math.absf %arg0 : f32
3543
// CHECK: spirv.CL.ceil %{{.*}}: f32
36-
%12 = math.ceil %arg0 : f32
44+
%14 = math.ceil %arg0 : f32
3745
// CHECK: spirv.CL.floor %{{.*}}: f32
38-
%13 = math.floor %arg0 : f32
46+
%15 = math.floor %arg0 : f32
3947
// CHECK: spirv.CL.erf %{{.*}}: f32
40-
%14 = math.erf %arg0 : f32
48+
%16 = math.erf %arg0 : f32
4149
// CHECK: spirv.CL.round %{{.*}}: f32
42-
%15 = math.round %arg0 : f32
50+
%17 = math.round %arg0 : f32
4351
return
4452
}
4553

@@ -61,16 +69,24 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
6169
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
6270
// CHECK: spirv.CL.log %[[ADDONE]]
6371
%5 = math.log1p %arg0 : vector<3xf32>
72+
// CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
73+
// CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
74+
// CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
75+
%6 = math.log2 %arg0 : vector<3xf32>
76+
// CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
77+
// CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
78+
// CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
79+
%7 = math.log10 %arg0 : vector<3xf32>
6480
// CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
65-
%6 = math.roundeven %arg0 : vector<3xf32>
81+
%8 = math.roundeven %arg0 : vector<3xf32>
6682
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
67-
%7 = math.rsqrt %arg0 : vector<3xf32>
83+
%9 = math.rsqrt %arg0 : vector<3xf32>
6884
// CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
69-
%8 = math.sqrt %arg0 : vector<3xf32>
85+
%10 = math.sqrt %arg0 : vector<3xf32>
7086
// CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
71-
%9 = math.tanh %arg0 : vector<3xf32>
87+
%11 = math.tanh %arg0 : vector<3xf32>
7288
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
73-
%10 = math.sin %arg0 : vector<3xf32>
89+
%12 = math.sin %arg0 : vector<3xf32>
7490
return
7591
}
7692

0 commit comments

Comments
 (0)