// Patch summary (commentary only; it sits ahead of the first `diff --git`
// header so that no hunk content and no hunk line count is altered).
//
// mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
//   * Adds an include of "mlir/Dialect/Arith/Transforms/Passes.h".
//   * In the hunk under `struct ArithToLLVMConversionPass`, the single call
//     `mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns)`
//     is replaced by `arith::populateCeilFloorDivExpandOpsPatterns(patterns)`
//     followed by `arith::populateArithToLLVMConversionPatterns(converter,
//     patterns)`; both run before the existing `applyPartialConversion` call.
//   * In `populateConvertToLLVMConversionPatterns` of
//     `struct ArithToLLVMDialectInterface`, the same
//     `populateCeilFloorDivExpandOpsPatterns(patterns)` call is added ahead of
//     the existing `populateArithToLLVMConversionPatterns(typeConverter,
//     patterns)` call, so both entry points register the same two pattern sets.
//   NOTE(review): presumably the new include is what declares
//   `populateCeilFloorDivExpandOpsPatterns`, which would make this conversion
//   depend on the Arith transforms library; no build-file change is part of
//   this patch -- TODO confirm the link dependency is already satisfied.
//
// mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
//   * Adds three test functions: `@ceildivsi` (i64), `@ceildivui` (i32) and
//     `@floordivsi` (i32, i32), using `arith.ceildivsi`, `arith.ceildivui` and
//     `arith.floordivsi` respectively. Every added CHECK line matches an
//     `llvm.*` operation (constants, icmp, select, add/sub/mul, sdiv/udiv,
//     and/or).
//   NOTE(review): `@ceildivsi` and `@ceildivui` pass `%arg0` as both operands,
//   so their CHECK lines cannot distinguish the lhs from the rhs in the lowered
//   ops; `@floordivsi` uses two distinct arguments and does not have this
//   limitation. Consider distinct operands in a follow-up.
//   NOTE(review): there is no `// -----` separator between `@ceildivsi` and
//   `@ceildivui`, unlike the other added functions -- presumably intentional,
//   verify against the file's split-input conventions.
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index aac24f113d891..754ed89814293 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -477,7 +478,8 @@ struct ArithToLLVMConversionPass options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); - mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); + arith::populateCeilFloorDivExpandOpsPatterns(patterns); + arith::populateArithToLLVMConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -503,6 +505,7 @@ struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { + arith::populateCeilFloorDivExpandOpsPatterns(patterns); arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); } }; diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 64c40f1aba43b..a9dcc0a16b3db 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -540,6 +540,68 @@ func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 { // ----- +// CHECK-LABEL: @ceildivsi +// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64 +func.func @ceildivsi(%arg0 : i64) -> i64 { + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[CST1:.*]] = 
llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64 + // CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64 + // CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64 + // CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64 + // CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64 + // CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64 + // CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64 + // CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64 + // CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64 + // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64 + // CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64 + // CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64 + // CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64 + // CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1 + // CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1 + // CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1 + // CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64 + %0 = arith.ceildivsi %arg0, %arg0 : i64 + return %0: i64 +} + +// CHECK-LABEL: @ceildivui +// CHECK-SAME: %[[ARG0:.*]]: i32) -> i32 +func.func @ceildivui(%arg0 : i32) -> i32 { +// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[ARG0]], %[[CST0]] : i32 +// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[SUB0:.*]] = llvm.sub %[[ARG0]], %[[CST1]] : i32 +// CHECK: %[[DIV0:.*]] = llvm.udiv %[[SUB0]], %[[ARG0]] : i32 +// CHECK: %[[ADD0:.*]] = llvm.add %[[DIV0]], %[[CST1]] : i32 +// CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST0]], %[[ADD0]] : i1, i32 + %0 = arith.ceildivui %arg0, %arg0 : i32 + return %0: i32 +} + +// ----- + +// CHECK-LABEL: @floordivsi +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32) -> i32 +func.func @floordivsi(%arg0 : 
i32, %arg1 : i32) -> i32 { + // CHECK: %[[SDIV:.*]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i32 + // CHECK: %[[MUL0:.*]] = llvm.mul %[[SDIV]], %[[ARG1]] : i32 + // CHECK: %[[CMP0:.*]] = llvm.icmp "ne" %[[ARG0]], %[[MUL0]] : i32 + // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST0]] : i32 + // CHECK: %[[CMP2:.*]] = llvm.icmp "slt" %[[ARG1]], %[[CST0]] : i32 + // CHECK: %[[CMP3:.*]] = llvm.icmp "ne" %[[CMP1]], %[[CMP2]] : i1 + // CHECK: %[[AND:.*]] = llvm.and %[[CMP0]], %[[CMP3]] : i1 + // CHECK: %[[CST1:.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: %[[ADD:.*]] = llvm.add %[[SDIV]], %[[CST1]] : i32 + // CHECK: %[[SEL:.*]] = llvm.select %[[AND]], %[[ADD]], %[[SDIV]] : i1, i32 + %0 = arith.floordivsi %arg0, %arg1 : i32 + return %0 : i32 +} + +// ----- + // CHECK-LABEL: @minmaxi func.func @minmaxi(%arg0 : i32, %arg1 : i32) -> i32 { // CHECK: = llvm.intr.smin(%arg0, %arg1) : (i32, i32) -> i32