Skip to content

[RISCV][VP] Introduce vp saturating addition/subtraction and RISC-V support. #82370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 23, 2024

Conversation

yetingk
Copy link
Contributor

@yetingk yetingk commented Feb 20, 2024

This patch also pick the MatchContext framework from DAGCombiner to an indiviual header file to make the framework be used from other files in llvm/lib/CodeGen/SelectionDAG/.

…support.

This patch also pick the MatchContext framework from DAGCombiner to an indiviual
header file to make the framework be used from other files in
llvm/lib/CodeGen/SelectionDAG/.
@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-risc-v

Author: Yeting Kuo (yetingk)

Changes

This patch also pick the MatchContext framework from DAGCombiner to an indiviual header file to make the framework be used from other files in llvm/lib/CodeGen/SelectionDAG/.


Patch is 455.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82370.diff

15 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+199)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+20)
  • (modified) llvm/include/llvm/IR/VPIntrinsics.def (+24)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+2-135)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+23-13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+8-8)
  • (added) llvm/lib/CodeGen/SelectionDAG/MatchContext.h (+175)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+11-1)
  • (added) llvm/test/CodeGen/RISCV/rvv/vsadd-vp.ll (+2015)
  • (added) llvm/test/CodeGen/RISCV/rvv/vsaddu-vp.ll (+2014)
  • (added) llvm/test/CodeGen/RISCV/rvv/vssub-vp.ll (+2067)
  • (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.ll (+2065)
  • (added) llvm/test/CodeGen/RISCV/rvv/vssubu-vp.s ()
  • (modified) llvm/unittests/IR/VPIntrinsicTest.cpp (+8)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index fd2e3aacd0169c..676ba8a41362bf 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -16718,6 +16718,7 @@ an operation is greater than the maximum value, the result is set (or
 "clamped") to this maximum. If it is below the minimum, it is clamped to this
 minimum.
 
+.. _int_sadd_sat:
 
 '``llvm.sadd.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -16767,6 +16768,8 @@ Examples
       %res = call i4 @llvm.sadd.sat.i4(i4 -4, i4 -5)  ; %res = -8
 
 
+.. _int_uadd_sat:
+
 '``llvm.uadd.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -16814,6 +16817,8 @@ Examples
       %res = call i4 @llvm.uadd.sat.i4(i4 8, i4 8)  ; %res = 15
 
 
+.. _int_ssub_sat:
+
 '``llvm.ssub.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -16862,6 +16867,8 @@ Examples
       %res = call i4 @llvm.ssub.sat.i4(i4 4, i4 -5)  ; %res = 7
 
 
+.. _int_usub_sat:
+
 '``llvm.usub.sat.*``' Intrinsics
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -23579,6 +23586,198 @@ Examples:
       %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
 
 
+.. _int_vp_sadd_sat:
+
+'``llvm.vp.sadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.sadd.sat.v16i32 (<16 x i32> <left_op> <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.sadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.sadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operand and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.sadd.sat``' intrinsic performs sadd.sat (:ref:`sadd.sat <int_sadd_sat>`) of the first, second,
+vector operand on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_uadd_sat:
+
+'``llvm.vp.uadd.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.uadd.sat.v16i32 (<16 x i32> <left_op> <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.uadd.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.uadd.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating addition of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operand and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.uadd.sat``' intrinsic performs uadd.sat (:ref:`uadd.sat <int_uadd_sat>`) of the first, second,
+vector operand on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_ssub_sat:
+
+'``llvm.vp.ssub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.ssub.sat.v16i32 (<16 x i32> <left_op> <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.ssub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.ssub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated signed saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operand and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.ssub.sat``' intrinsic performs ssub.sat (:ref:`ssub.sat <int_ssub_sat>`) of the first, second,
+vector operand on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.ssub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
+.. _int_vp_usub_sat:
+
+'``llvm.vp.usub.sat.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <16 x i32>  @llvm.vp.usub.sat.v16i32 (<16 x i32> <left_op> <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>)
+      declare <vscale x 4 x i32>  @llvm.vp.usub.sat.nxv4i32 (<vscale x 4 x i32> <left_op>, <vscale x 4 x i32> <right_op>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare <256 x i64>  @llvm.vp.usub.sat.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+Predicated unsigned saturating subtraction of two vectors of integers.
+
+
+Arguments:
+""""""""""
+
+The first two operand and the result have the same vector of integer type. The
+third operand is the vector mask and has the same number of elements as the
+result vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.usub.sat``' intrinsic performs usub.sat (:ref:`usub.sat <int_usub_sat>`) of the first, second,
+vector operand on each enabled lane. The result on disabled lanes is a :ref:`poison value <poisonvalues>`.
+
+
+Examples:
+"""""""""
+
+.. code-block:: llvm
+
+      %r = call <4 x i32> @llvm.vp.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %evl)
+      ;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
+
+      %t = call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> %b)
+      %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+
+
 .. _int_vp_fshl:
 
 '``llvm.vp.fshl.*``' Intrinsics
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8c0d4d5db32d88..d7c1ce153a6c80 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1933,6 +1933,26 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
                                LLVMMatchType<0>,
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                llvm_i32_ty]>;
+  def int_vp_sadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_uadd_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_ssub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
+  def int_vp_usub_sat : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
+                             [ LLVMMatchType<0>,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                               llvm_i32_ty]>;
 
   // Floating-point arithmetic
   def int_vp_fadd : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ],
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 3b32b60609f536..4089acf9ec3f05 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -293,6 +293,30 @@ BEGIN_REGISTER_VP(vp_fshr, 3, 4, VP_FSHR, -1)
 VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshr)
 VP_PROPERTY_FUNCTIONAL_SDOPC(FSHR)
 END_REGISTER_VP(vp_fshr, VP_FSHR)
+
+// llvm.vp.sadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_sadd_sat, 2, 3, VP_SADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(sadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SADDSAT)
+END_REGISTER_VP(vp_sadd_sat, VP_SADDSAT)
+
+// llvm.vp.uadd.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_uadd_sat, 2, 3, VP_UADDSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(uadd_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(UADDSAT)
+END_REGISTER_VP(vp_uadd_sat, VP_UADDSAT)
+
+// llvm.vp.ssub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_ssub_sat, 2, 3, VP_SSUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(ssub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(SSUBSAT)
+END_REGISTER_VP(vp_ssub_sat, VP_SSUBSAT)
+
+// llvm.vp.usub.sat(x,y,mask,vlen)
+BEGIN_REGISTER_VP(vp_usub_sat, 2, 3, VP_USUBSAT, -1)
+VP_PROPERTY_FUNCTIONAL_INTRINSIC(usub_sat)
+VP_PROPERTY_FUNCTIONAL_SDOPC(USUBSAT)
+END_REGISTER_VP(vp_usub_sat, VP_USUBSAT)
 ///// } Integer Arithmetic
 
 ///// Floating-Point Arithmetic {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2a09e44e192979..318e1c12c3d56f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -76,6 +76,8 @@
 #include <utility>
 #include <variant>
 
+#include "MatchContext.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "dagcombine"
@@ -888,141 +890,6 @@ class WorklistInserter : public SelectionDAG::DAGUpdateListener {
   void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
 };
 
-class EmptyMatchContext {
-  SelectionDAG &DAG;
-  const TargetLowering &TLI;
-
-public:
-  EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI) {}
-
-  bool match(SDValue OpN, unsigned Opcode) const {
-    return Opcode == OpN->getOpcode();
-  }
-
-  // Same as SelectionDAG::getNode().
-  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
-    return DAG.getNode(std::forward<ArgT>(Args)...);
-  }
-
-  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
-                                bool LegalOnly = false) const {
-    return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
-  }
-};
-
-class VPMatchContext {
-  SelectionDAG &DAG;
-  const TargetLowering &TLI;
-  SDValue RootMaskOp;
-  SDValue RootVectorLenOp;
-
-public:
-  VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
-    assert(Root->isVPOpcode());
-    if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
-      RootMaskOp = Root->getOperand(*RootMaskPos);
-    else if (Root->getOpcode() == ISD::VP_SELECT)
-      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
-                                          Root->getOperand(0).getValueType());
-
-    if (auto RootVLenPos =
-            ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
-      RootVectorLenOp = Root->getOperand(*RootVLenPos);
-  }
-
-  /// whether \p OpVal is a node that is functionally compatible with the
-  /// NodeType \p Opc
-  bool match(SDValue OpVal, unsigned Opc) const {
-    if (!OpVal->isVPOpcode())
-      return OpVal->getOpcode() == Opc;
-
-    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
-                                           !OpVal->getFlags().hasNoFPExcept());
-    if (BaseOpc != Opc)
-      return false;
-
-    // Make sure the mask of OpVal is true mask or is same as Root's.
-    unsigned VPOpcode = OpVal->getOpcode();
-    if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
-      SDValue MaskOp = OpVal.getOperand(*MaskPos);
-      if (RootMaskOp != MaskOp &&
-          !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
-        return false;
-    }
-
-    // Make sure the EVL of OpVal is same as Root's.
-    if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
-      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
-        return false;
-    return true;
-  }
-
-  // Specialize based on number of operands.
-  // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
-  // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
-  // DAG.getNode(Opcode, DL, VT); }
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {Operand, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDValue N3) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
-                  SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
-                       Flags);
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
-                       Flags);
-  }
-
-  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
-                  SDValue N2, SDValue N3, SDNodeFlags Flags) {
-    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
-    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
-           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
-  }
-
-  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
-                                bool LegalOnly = false) const {
-    unsigned VPOp = ISD::getVPForBaseOpcode(Op);
-    return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
-  }
-};
-
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index a4ba261686c688..b5b33bb279e8b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -217,7 +217,15 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SSUBSAT:
   case ISD::USUBSAT:
   case ISD::SSHLSAT:
-  case ISD::USHLSAT:     Res = PromoteIntRes_ADDSUBSHLSAT(N); break;
+  case ISD::USHLSAT:
+    Res = PromoteIntRes_ADDSUBSHLSAT<EmptyMatchContext>(N);
+    break;
+  case ISD::VP_SADDSAT:
+  case ISD::VP_UADDSAT:
+  case ISD::VP_SSUBSAT:
+  case ISD::VP_USUBSAT:
+    Res = PromoteIntRes_ADDSUBSHLSAT<VPMatchContext>(N);
+    break;
 
   case ISD::SMULFIX:
   case ISD::SMULFIXSAT:
@@ -934,6 +942,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
   return DAG.getBoolExtOrTrunc(Res.getValue(1), dl, NVT, VT);
 }
 
+template <class MatchContextClass>
 SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   // If the promoted type is legal, we can convert this to:
   //   1. ANY_EXTEND iN to iM
@@ -945,9 +954,10 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   SDLoc dl(N);
   SDValue Op1 = N->getOperand(0);
   SDValue Op2 = N->getOperand(1);
+  MatchContextClass matcher(DAG, TLI, N);
   unsigned OldBits = Op1.getScalarValueSizeInBits();
 
-  unsigned Opcode = N->getOpcode();
+  unsigned Opcode = matcher.getRootBaseOpcode();
   bool IsShift = Opcode == ISD::USHLSAT || Opcode == ISD::SSHLSAT;
 
   SDValue Op1Promoted, Op2Promoted;
@@ -968,18 +978,18 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
     APInt MaxVal = APInt::getAllOnes(OldBits).zext(NewBits);
     SDValue SatMax = DAG.getConstant(MaxVal, dl, PromotedType);
     SDValue Add =
-        DAG.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
-    return DAG.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
+        matcher.getNode(ISD::ADD, dl, PromotedType, Op1Promoted, Op2Promoted);
+    return matcher.getNode(ISD::UMIN, dl, PromotedType, Add, SatMax);
   }
 
   // USUBSAT can always be promoted as long as we have zero-extended the args.
   if (Opcode == ISD::USUBSAT)
-    return DAG.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
-                       Op2Promoted);
+    return matcher.getNode(ISD::USUBSAT, dl, PromotedType, Op1Promoted,
+                           Op2Promoted);
 
   // Shift cannot use a min/max expansion, we can't detect overflow if all of
   // the bits have been shifted out.
-  if (IsShift || TLI.isOperationLegal(Opcode, PromotedType)) {
+  if (IsShift || matcher.isOperationLegal(Opcode, PromotedType)) {
     unsigned ShiftOp;
     switch (Opcode) {
     case ISD::SADDSAT:
@@ -1002,11 +1012,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
         DAG.getNode(ISD::SHL, dl, PromotedType, Op1Promoted, ShiftAmount);
     if (!IsShift)
       Op2Promoted =
-          DAG.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
+          matcher.getNode(ISD::SHL, dl, PromotedType, Op2Promoted, ShiftAmount);
 
     SDValue Result =
-        DAG.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
-    return DAG.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
+        matcher.getNode(Opcode, dl, PromotedType, Op1Promoted, Op2Promoted);
+    return matcher.getNode(ShiftOp, dl, PromotedType, Result, ShiftAmount);
   }
 
   unsigned AddOp = Opcode == ISD::SADDSAT ? ISD::ADD : ISD::SUB;
@@ -1015,9 +1025,9 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBSHLSAT(SDNode *N) {
   SDValue SatMin = DAG.getConstant(MinVal, dl, PromotedType);
   SDValue SatMax = DAG.getConstant(MaxVal,...
[truncated]

Copy link

github-actions bot commented Feb 20, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff bb049094d5d32e562e5831e5bd0fa12c0a7984ce 4ecfaa50170daa0c5be6840516561f6d674e91f7 -- llvm/lib/CodeGen/SelectionDAG/MatchContext.h llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp llvm/lib/Target/RISCV/RISCVISelLowering.cpp llvm/unittests/IR/VPIntrinsicTest.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 90cda2a115..7ccd27e05d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1163,10 +1163,14 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::SMAX: case ISD::VP_SMAX:
   case ISD::UMIN: case ISD::VP_UMIN:
   case ISD::UMAX: case ISD::VP_UMAX:
-  case ISD::SADDSAT: case ISD::VP_SADDSAT:
-  case ISD::UADDSAT: case ISD::VP_UADDSAT:
-  case ISD::SSUBSAT: case ISD::VP_SSUBSAT:
-  case ISD::USUBSAT: case ISD::VP_USUBSAT:
+  case ISD::SADDSAT:
+  case ISD::VP_SADDSAT:
+  case ISD::UADDSAT:
+  case ISD::VP_UADDSAT:
+  case ISD::SSUBSAT:
+  case ISD::VP_SSUBSAT:
+  case ISD::USUBSAT:
+  case ISD::VP_USUBSAT:
   case ISD::SSHLSAT:
   case ISD::USHLSAT:
   case ISD::ROTL:
@@ -4140,10 +4144,14 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::SMAX: case ISD::VP_SMAX:
   case ISD::UMIN: case ISD::VP_UMIN:
   case ISD::UMAX: case ISD::VP_UMAX:
-  case ISD::UADDSAT: case ISD::VP_UADDSAT:
-  case ISD::SADDSAT: case ISD::VP_SADDSAT:
-  case ISD::USUBSAT: case ISD::VP_USUBSAT:
-  case ISD::SSUBSAT: case ISD::VP_SSUBSAT:
+  case ISD::UADDSAT:
+  case ISD::VP_UADDSAT:
+  case ISD::SADDSAT:
+  case ISD::VP_SADDSAT:
+  case ISD::USUBSAT:
+  case ISD::VP_USUBSAT:
+  case ISD::SSUBSAT:
+  case ISD::VP_SSUBSAT:
   case ISD::SSHLSAT:
   case ISD::USHLSAT:
   case ISD::ROTL:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 91264d8ce2..729cd61e91 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -677,21 +677,46 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
                        MVT::Other, Custom);
 
-    static const unsigned IntegerVPOps[] = {
-        ISD::VP_ADD,         ISD::VP_SUB,         ISD::VP_MUL,
-        ISD::VP_SDIV,        ISD::VP_UDIV,        ISD::VP_SREM,
-        ISD::VP_UREM,        ISD::VP_AND,         ISD::VP_OR,
-        ISD::VP_XOR,         ISD::VP_ASHR,        ISD::VP_LSHR,
-        ISD::VP_SHL,         ISD::VP_REDUCE_ADD,  ISD::VP_REDUCE_AND,
-        ISD::VP_REDUCE_OR,   ISD::VP_REDUCE_XOR,  ISD::VP_REDUCE_SMAX,
-        ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
-        ISD::VP_MERGE,       ISD::VP_SELECT,      ISD::VP_FP_TO_SINT,
-        ISD::VP_FP_TO_UINT,  ISD::VP_SETCC,       ISD::VP_SIGN_EXTEND,
-        ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE,    ISD::VP_SMIN,
-        ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
-        ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
-        ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT};
+    static const unsigned IntegerVPOps[] = {ISD::VP_ADD,
+                                            ISD::VP_SUB,
+                                            ISD::VP_MUL,
+                                            ISD::VP_SDIV,
+                                            ISD::VP_UDIV,
+                                            ISD::VP_SREM,
+                                            ISD::VP_UREM,
+                                            ISD::VP_AND,
+                                            ISD::VP_OR,
+                                            ISD::VP_XOR,
+                                            ISD::VP_ASHR,
+                                            ISD::VP_LSHR,
+                                            ISD::VP_SHL,
+                                            ISD::VP_REDUCE_ADD,
+                                            ISD::VP_REDUCE_AND,
+                                            ISD::VP_REDUCE_OR,
+                                            ISD::VP_REDUCE_XOR,
+                                            ISD::VP_REDUCE_SMAX,
+                                            ISD::VP_REDUCE_SMIN,
+                                            ISD::VP_REDUCE_UMAX,
+                                            ISD::VP_REDUCE_UMIN,
+                                            ISD::VP_MERGE,
+                                            ISD::VP_SELECT,
+                                            ISD::VP_FP_TO_SINT,
+                                            ISD::VP_FP_TO_UINT,
+                                            ISD::VP_SETCC,
+                                            ISD::VP_SIGN_EXTEND,
+                                            ISD::VP_ZERO_EXTEND,
+                                            ISD::VP_TRUNCATE,
+                                            ISD::VP_SMIN,
+                                            ISD::VP_SMAX,
+                                            ISD::VP_UMIN,
+                                            ISD::VP_UMAX,
+                                            ISD::VP_ABS,
+                                            ISD::EXPERIMENTAL_VP_REVERSE,
+                                            ISD::EXPERIMENTAL_VP_SPLICE,
+                                            ISD::VP_SADDSAT,
+                                            ISD::VP_UADDSAT,
+                                            ISD::VP_SSUBSAT,
+                                            ISD::VP_USUBSAT};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,

unsigned OldBits = Op1.getScalarValueSizeInBits();

unsigned Opcode = N->getOpcode();
unsigned Opcode = matcher.getRootBaseOpcode();
bool IsShift = Opcode == ISD::USHLSAT || Opcode == ISD::SSHLSAT;

SDValue Op1Promoted, Op2Promoted;
Copy link
Collaborator

@topperc topperc Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have VP aware versions of ZExtPromotedInteger and SExtPromotedInteger?

I guess that would require a VP_SIGN_EXTEND_INREG so maybe for a future patch. Or we'd have to emit a pair of VP_SHL and VP_SRA

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a FIXME for this instead to fix it in this PR. Is it Ok?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Arguments:
""""""""""

The first two operand and the result have the same vector of integer type. The
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"two operand" -> "two operands". Is this a typo in existing descriptions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. It's just my typo.

Semantics:
""""""""""

The '``llvm.vp.sadd.sat``' intrinsic performs sadd.sat (:ref:`sadd.sat <int_sadd_sat>`) of the first, second,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"first, second, vector operand -> "first and second vector operands"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc changed the title [RISCV][VP] Introduce vp saturating addition/substraction and RISC-V support. [RISCV][VP] Introduce vp saturating addition/subtraction and RISC-V support. Feb 23, 2024
@yetingk yetingk merged commit 850dde0 into llvm:main Feb 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants