diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 85ec288268aeb..97da96afac4cd 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -18,6 +18,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/FloatingPointMode.h" + namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -286,6 +288,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { } }; +struct IsNaNOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operandType = adaptor.getOperand().getType(); + + if (!operandType || !LLVM::isCompatibleType(operandType)) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), adaptor.getOperand(), llvm::fcNan); + return success(); + } +}; + +struct IsFiniteOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operandType = adaptor.getOperand().getType(); + + if (!operandType || !LLVM::isCompatibleType(operandType)) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), adaptor.getOperand(), llvm::fcFinite); + return success(); + } +}; + struct ConvertMathToLLVMPass : public impl::ConvertMathToLLVMPassBase { using Base::Base; @@ -309,6 +345,8 @@ void mlir::populateMathToLLVMConversionPatterns( patterns.add(converter, benefit); // clang-format off patterns.add< + IsNaNOpLowering, + IsFiniteOpLowering, AbsFOpLowering, AbsIOpLowering, CeilOpLowering, diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index 45a37af293890..974743a55932b 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -263,6 +263,26 @@ func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> { // ----- +// CHECK-LABEL: func @isnan_double( +// CHECK-SAME: f64 +func.func @isnan_double(%arg0 : f64) { + // CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 3 : i32}> : (f64) -> i1 + %0 = math.isnan %arg0 : f64 + func.return +} + +// ----- + +// CHECK-LABEL: func @isfinite_double( +// CHECK-SAME: f64 +func.func @isfinite_double(%arg0 : f64) { + // CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 504 : i32}> : (f64) -> i1 + %0 = math.isfinite %arg0 : f64 + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt_double( // CHECK-SAME: f64 func.func @rsqrt_double(%arg0 : f64) {