@@ -291,6 +291,65 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
291
291
}
292
292
};
293
293
294
+ // / Converts math.log2 and math.log10 to SPIR-V ops.
295
+ // /
296
+ // / SPIR-V does not have direct operations for log2 and log10. Explicitly
297
+ // / lower to these operations using:
298
+ // / log2(x) = log(x) * 1/log(2)
299
+ // / log10(x) = log(x) * 1/log(10)
300
+
301
+ template <typename MathLogOp, typename SpirvLogOp>
302
+ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> {
303
+ using OpConversionPattern<MathLogOp>::OpConversionPattern;
304
+ using typename OpConversionPattern<MathLogOp>::OpAdaptor;
305
+
306
+ static constexpr double log2Reciprocal =
307
+ 1.442695040888963407359924681001892137426645954152985934135449407 ;
308
+ static constexpr double log10Reciprocal =
309
+ 0.4342944819032518276511289189166050822943970058036665661144537832 ;
310
+
311
+ LogicalResult
312
+ matchAndRewrite (MathLogOp operation, OpAdaptor adaptor,
313
+ ConversionPatternRewriter &rewriter) const override {
314
+ assert (adaptor.getOperands ().size () == 1 );
315
+ if (LogicalResult res = checkSourceOpTypes (rewriter, operation);
316
+ failed (res))
317
+ return res;
318
+
319
+ Location loc = operation.getLoc ();
320
+ Type type = this ->getTypeConverter ()->convertType (operation.getType ());
321
+ if (!type)
322
+ return rewriter.notifyMatchFailure (operation, " type conversion failed" );
323
+
324
+ auto getConstantValue = [&](double value) {
325
+ if (auto floatType = dyn_cast<FloatType>(type)) {
326
+ return rewriter.create <spirv::ConstantOp>(
327
+ loc, type, rewriter.getFloatAttr (floatType, value));
328
+ }
329
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
330
+ Type elemType = vectorType.getElementType ();
331
+
332
+ if (isa<FloatType>(elemType)) {
333
+ return rewriter.create <spirv::ConstantOp>(
334
+ loc, type,
335
+ DenseFPElementsAttr::get (
336
+ vectorType, FloatAttr::get (elemType, value).getValue ()));
337
+ }
338
+ }
339
+
340
+ llvm_unreachable (" unimplemented types for log2/log10" );
341
+ };
342
+
343
+ Value constantValue = getConstantValue (
344
+ std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345
+ : log10Reciprocal);
346
+ Value log = rewriter.create <SpirvLogOp>(loc, adaptor.getOperand ());
347
+ rewriter.replaceOpWithNewOp <spirv::FMulOp>(operation, type, log ,
348
+ constantValue);
349
+ return success ();
350
+ }
351
+ };
352
+
294
353
// / Converts math.powf to SPIRV-Ops.
295
354
struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
296
355
using OpConversionPattern::OpConversionPattern;
@@ -411,6 +470,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
411
470
// GLSL patterns
412
471
patterns
413
472
.add <CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
473
+ Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
474
+ Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
414
475
ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
415
476
CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
416
477
CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
@@ -430,6 +491,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
430
491
431
492
// OpenCL patterns
432
493
patterns.add <Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
494
+ Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
495
+ Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
433
496
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
434
497
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
435
498
CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
0 commit comments