Skip to content

[HLSL] Add support to lookup a ResourceBindingInfo from its use #126556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 19, 2025

Conversation

V-FEXrt
Copy link
Contributor

@V-FEXrt V-FEXrt commented Feb 10, 2025

Adds findByUse which takes a llvm::Value from a use and resolves it (as best as possible) back to the creation of that resource.

It may return multiple ResourceBindingInfo if the use comes from branched control flow.

Fixes #125746

Copy link

github-actions bot commented Feb 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@V-FEXrt V-FEXrt marked this pull request as ready for review February 14, 2025 22:25
@llvmbot llvmbot added backend:DirectX llvm:analysis Includes value tracking, cost tables and constant folding labels Feb 14, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-backend-directx

Author: Ashley Coleman (V-FEXrt)

Changes

Adds findByUse which takes a llvm::Value from a use and resolves it (as best as possible) back to the creation of that resource.

It may return multiple ResourceBindingInfo if the use comes from branched control flow.

Fixes #125746


Full diff: https://github.com/llvm/llvm-project/pull/126556.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+4)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+44)
  • (modified) llvm/unittests/Target/DirectX/CMakeLists.txt (+2)
  • (added) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+309)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 87c5615c28ee0..9e1e3a6dfc50b 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -446,6 +446,10 @@ class DXILBindingMap {
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
   }
 
+  // Resoloves the use of a resource handle into the unique description of that
+  // resource by deduping calls to create.
+  SmallVector<dxil::ResourceBindingInfo> findByUse(const Value *Key) const;
+
   const_iterator find(const CallInst *Key) const {
     auto Pos = CallMap.find(Key);
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 7f28e63cc117d..25ff7db7a4d71 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -770,6 +770,50 @@ void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
   }
 }
 
+SmallVector<dxil::ResourceBindingInfo>
+DXILBindingMap::findByUse(const Value *Key) const {
+  const PHINode *Phi = dyn_cast<PHINode>(Key);
+  if (Phi) {
+    SmallVector<dxil::ResourceBindingInfo> Children;
+    for (const Value *V : Phi->operands()) {
+      Children.append(findByUse(V));
+    }
+    return Children;
+  }
+
+  const CallInst *CI = dyn_cast<CallInst>(Key);
+  if (!CI) {
+    return {};
+  }
+
+  const Type *UseType = CI->getType();
+
+  switch (CI->getIntrinsicID()) {
+  // Check if any of the parameters are the resource we are following. If so
+  // keep searching
+  case Intrinsic::not_intrinsic: {
+    SmallVector<dxil::ResourceBindingInfo> Children;
+    for (const Value *V : CI->args()) {
+      if (V->getType() != UseType) {
+        continue;
+      }
+
+      Children.append(findByUse(V));
+    }
+
+    return Children;
+  }
+  // Found the create, return the binding
+  case Intrinsic::dx_resource_handlefrombinding:
+    const auto *It = find(CI);
+    if (It == Infos.end())
+      return {};
+    return {*It};
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 
 AnalysisKey DXILResourceTypeAnalysis::Key;
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
index 626c0d6384268..fd0d5a0dd52c1 100644
--- a/llvm/unittests/Target/DirectX/CMakeLists.txt
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -8,10 +8,12 @@ set(LLVM_LINK_COMPONENTS
   Core
   DirectXCodeGen
   DirectXPointerTypeAnalysis
+  Passes
   Support
   )
 
 add_llvm_target_unittest(DirectXTests
   CBufferDataLayoutTests.cpp
   PointerTypeAnalysisTests.cpp
+  UniqueResourceFromUseTests.cpp
   )
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
new file mode 100644
index 0000000000000..5ad7330f05a45
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -0,0 +1,309 @@
+//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DirectXIRPasses/PointerTypeAnalysis.h"
+#include "DirectXTargetMachine.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/Debugify.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+template <typename T> struct IsA {
+  friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+};
+
+namespace {
+class UniqueResourceFromUseTest : public testing::Test {
+protected:
+  PassBuilder *PB;
+  ModuleAnalysisManager *MAM;
+
+  virtual void SetUp() {
+    MAM = new ModuleAnalysisManager();
+    PB = new PassBuilder();
+    PB->registerModuleAnalyses(*MAM);
+    MAM->registerPass([&] { return DXILResourceTypeAnalysis(); });
+    MAM->registerPass([&] { return DXILResourceBindingAnalysis(); });
+  }
+
+  virtual void TearDown() {
+    delete PB;
+    delete MAM;
+  }
+};
+
+TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 1u)
+          << "Handle should resolve into one resource";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(2u, CalledResources)
+        << "Expected 2 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
+  %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle4)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 1u)
+          << "Handle should resolve into one resource";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
+  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
+  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
+  %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 4u)
+          << "Handle should resolve into four resources";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(1u, Binding.LowerBound);
+      EXPECT_EQ(1u, Binding.Size);
+
+      Binding = Bindings[1].getBinding();
+      EXPECT_EQ(1u, Binding.RecordID);
+      EXPECT_EQ(2u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(2u, Binding.Size);
+
+      Binding = Bindings[2].getBinding();
+      EXPECT_EQ(2u, Binding.RecordID);
+      EXPECT_EQ(3u, Binding.Space);
+      EXPECT_EQ(3u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      Binding = Bindings[3].getBinding();
+      EXPECT_EQ(3u, Binding.RecordID);
+      EXPECT_EQ(4u, Binding.Space);
+      EXPECT_EQ(4u, Binding.LowerBound);
+      EXPECT_EQ(4u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
+  StringRef Assembly = R"(
+define void @foo(i32 %n) {
+entry:
+  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %cond = icmp eq i32 %n, 0
+  br i1 %cond, label %bb.true, label %bb.false
+
+bb.true:
+  %handle_t = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  br label %bb.exit
+
+bb.false:
+  %handle_f = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %y)
+  br label %bb.exit
+
+bb.exit:
+  %handle = phi target("dx.RawBuffer", float, 1, 0) [ %handle_t, %bb.true ], [ %handle_f, %bb.false ]
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 2u)
+          << "Handle should resolve into four resources";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(1u, Binding.LowerBound);
+      EXPECT_EQ(1u, Binding.Size);
+
+      Binding = Bindings[1].getBinding();
+      EXPECT_EQ(1u, Binding.RecordID);
+      EXPECT_EQ(4u, Binding.Space);
+      EXPECT_EQ(4u, Binding.LowerBound);
+      EXPECT_EQ(4u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+} // namespace

@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Ashley Coleman (V-FEXrt)

Changes

Adds findByUse which takes a llvm::Value from a use and resolves it (as best as possible) back to the creation of that resource.

It may return multiple ResourceBindingInfo if the use comes from branched control flow.

Fixes #125746


Full diff: https://github.com/llvm/llvm-project/pull/126556.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+4)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+44)
  • (modified) llvm/unittests/Target/DirectX/CMakeLists.txt (+2)
  • (added) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+309)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 87c5615c28ee0..9e1e3a6dfc50b 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -446,6 +446,10 @@ class DXILBindingMap {
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
   }
 
+  // Resoloves the use of a resource handle into the unique description of that
+  // resource by deduping calls to create.
+  SmallVector<dxil::ResourceBindingInfo> findByUse(const Value *Key) const;
+
   const_iterator find(const CallInst *Key) const {
     auto Pos = CallMap.find(Key);
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 7f28e63cc117d..25ff7db7a4d71 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -770,6 +770,50 @@ void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
   }
 }
 
+SmallVector<dxil::ResourceBindingInfo>
+DXILBindingMap::findByUse(const Value *Key) const {
+  const PHINode *Phi = dyn_cast<PHINode>(Key);
+  if (Phi) {
+    SmallVector<dxil::ResourceBindingInfo> Children;
+    for (const Value *V : Phi->operands()) {
+      Children.append(findByUse(V));
+    }
+    return Children;
+  }
+
+  const CallInst *CI = dyn_cast<CallInst>(Key);
+  if (!CI) {
+    return {};
+  }
+
+  const Type *UseType = CI->getType();
+
+  switch (CI->getIntrinsicID()) {
+  // Check if any of the parameters are the resource we are following. If so
+  // keep searching
+  case Intrinsic::not_intrinsic: {
+    SmallVector<dxil::ResourceBindingInfo> Children;
+    for (const Value *V : CI->args()) {
+      if (V->getType() != UseType) {
+        continue;
+      }
+
+      Children.append(findByUse(V));
+    }
+
+    return Children;
+  }
+  // Found the create, return the binding
+  case Intrinsic::dx_resource_handlefrombinding:
+    const auto *It = find(CI);
+    if (It == Infos.end())
+      return {};
+    return {*It};
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 
 AnalysisKey DXILResourceTypeAnalysis::Key;
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
index 626c0d6384268..fd0d5a0dd52c1 100644
--- a/llvm/unittests/Target/DirectX/CMakeLists.txt
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -8,10 +8,12 @@ set(LLVM_LINK_COMPONENTS
   Core
   DirectXCodeGen
   DirectXPointerTypeAnalysis
+  Passes
   Support
   )
 
 add_llvm_target_unittest(DirectXTests
   CBufferDataLayoutTests.cpp
   PointerTypeAnalysisTests.cpp
+  UniqueResourceFromUseTests.cpp
   )
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
new file mode 100644
index 0000000000000..5ad7330f05a45
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -0,0 +1,309 @@
+//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DirectXIRPasses/PointerTypeAnalysis.h"
+#include "DirectXTargetMachine.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/Debugify.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+template <typename T> struct IsA {
+  friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+};
+
+namespace {
+class UniqueResourceFromUseTest : public testing::Test {
+protected:
+  PassBuilder *PB;
+  ModuleAnalysisManager *MAM;
+
+  virtual void SetUp() {
+    MAM = new ModuleAnalysisManager();
+    PB = new PassBuilder();
+    PB->registerModuleAnalyses(*MAM);
+    MAM->registerPass([&] { return DXILResourceTypeAnalysis(); });
+    MAM->registerPass([&] { return DXILResourceBindingAnalysis(); });
+  }
+
+  virtual void TearDown() {
+    delete PB;
+    delete MAM;
+  }
+};
+
+TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 1u)
+          << "Handle should resolve into one resource";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(2u, CalledResources)
+        << "Expected 2 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
+  %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle4)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 1u)
+          << "Handle should resolve into one resource";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
+  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
+  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
+  %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 4u)
+          << "Handle should resolve into four resources";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(1u, Binding.LowerBound);
+      EXPECT_EQ(1u, Binding.Size);
+
+      Binding = Bindings[1].getBinding();
+      EXPECT_EQ(1u, Binding.RecordID);
+      EXPECT_EQ(2u, Binding.Space);
+      EXPECT_EQ(2u, Binding.LowerBound);
+      EXPECT_EQ(2u, Binding.Size);
+
+      Binding = Bindings[2].getBinding();
+      EXPECT_EQ(2u, Binding.RecordID);
+      EXPECT_EQ(3u, Binding.Space);
+      EXPECT_EQ(3u, Binding.LowerBound);
+      EXPECT_EQ(3u, Binding.Size);
+
+      Binding = Bindings[3].getBinding();
+      EXPECT_EQ(3u, Binding.RecordID);
+      EXPECT_EQ(4u, Binding.Space);
+      EXPECT_EQ(4u, Binding.LowerBound);
+      EXPECT_EQ(4u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
+  StringRef Assembly = R"(
+define void @foo(i32 %n) {
+entry:
+  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %cond = icmp eq i32 %n, 0
+  br i1 %cond, label %bb.true, label %bb.false
+
+bb.true:
+  %handle_t = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  br label %bb.exit
+
+bb.false:
+  %handle_f = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %y)
+  br label %bb.exit
+
+bb.exit:
+  %handle = phi target("dx.RawBuffer", float, 1, 0) [ %handle_t, %bb.true ], [ %handle_f, %bb.false ]
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+  for (const Function &F : M->functions()) {
+    if (F.getName() != "a.func") {
+      continue;
+    }
+
+    unsigned CalledResources = 0;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+      const Value *Handle = CI->getArgOperand(0);
+
+      const auto Bindings = DBM.findByUse(Handle);
+      ASSERT_EQ(Bindings.size(), 2u)
+          << "Handle should resolve into four resources";
+
+      auto Binding = Bindings[0].getBinding();
+      EXPECT_EQ(0u, Binding.RecordID);
+      EXPECT_EQ(1u, Binding.Space);
+      EXPECT_EQ(1u, Binding.LowerBound);
+      EXPECT_EQ(1u, Binding.Size);
+
+      Binding = Bindings[1].getBinding();
+      EXPECT_EQ(1u, Binding.RecordID);
+      EXPECT_EQ(4u, Binding.Space);
+      EXPECT_EQ(4u, Binding.LowerBound);
+      EXPECT_EQ(4u, Binding.Size);
+
+      CalledResources++;
+    }
+
+    EXPECT_EQ(1u, CalledResources)
+        << "Expected 1 resolved call to create resource";
+  }
+}
+
+} // namespace

Copy link
Contributor

@joaosaffran joaosaffran left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just some minor comment.

@@ -782,9 +781,8 @@ DXILBindingMap::findByUse(const Value *Key) const {
}

const CallInst *CI = dyn_cast<CallInst>(Key);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it impact the recursion if we check for CallInst before PHINode? I kind of like my base cases to be first, but I know thats not always possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to test it but it looks like PHINode isn't a subclass of CallInst right? https://llvm.org/doxygen/classllvm_1_1PHINode.html

Which means the base case would earily exit before the PHINode code had a chance to run. Could maybe tuck it inside the if (!CI) but that feels a bit messsy to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah the trivial reorder causes the PHINodes case to fail

@farzonl
Copy link
Member

farzonl commented Feb 18, 2025

LGTM

/// ResourceBindingInfo can be used to depuplicate unique handles that
/// reference the same resource
SmallVector<dxil::ResourceBindingInfo>
findCreationInfo(const Value *Key) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely convinced this name is clearer than the findByValue type name in the original, though I do agree that name is a bit ambiguous as well. I guess we want to convey something like "this looks up the possible resources this could be" - maybe "walkDefinitions" or something could work? In any case I'll leave it to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I need to fix the typo anyways I'll just go back to the original name and keep it in the back of my mind for improvement. We still have to touch the function at least one more time so there will be another opportunity to review the name

@V-FEXrt V-FEXrt merged commit 02c9dae into llvm:main Feb 19, 2025
9 checks passed
@V-FEXrt V-FEXrt deleted the hlsl-125746-find-resource-by-use branch February 19, 2025 00:30
if (const PHINode *Phi = dyn_cast<PHINode>(Key)) {
SmallVector<dxil::ResourceBindingInfo> Children;
for (const Value *V : Phi->operands()) {
Children.append(findByUse(V));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without a set that tracks and skips already visited Value* pointers, this could result in infinite recursion, right?

Copy link
Contributor Author

@V-FEXrt V-FEXrt Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uhhh maybe, but I dont think so? Can you formulate an example? I've been trying myself for a bit and everything I come up with violates SSA.

Visited lists are only relevant when cycles are possible and in order to introduce a cycle we need a CallInstr to reference something not yet defined and that's not allowed right?

%bat = call @foo %bar
%bar = call @foo2 %bat

@llvm-ci
Copy link
Collaborator

llvm-ci commented Feb 19, 2025

LLVM Buildbot has detected a new failure on builder lld-x86_64-win running on as-worker-93 while building llvm at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/146/builds/2323

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM-Unit :: Support/./SupportTests.exe/38/87' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe-LLVM-Unit-18524-38-87.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=87 GTEST_SHARD_INDEX=38 C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe
--

Script:
--
C:\a\lld-x86_64-win\build\unittests\Support\.\SupportTests.exe --gtest_filter=ProgramEnvTest.CreateProcessLongPath
--
C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(160): error: Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp(163): error: fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied



C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:160
Expected equality of these values:
  0
  RC
    Which is: -2

C:\a\lld-x86_64-win\llvm-project\llvm\unittests\Support\ProgramTest.cpp:163
fs::remove(Twine(LongPath)): did not return errc::success.
error number: 13
error message: permission denied




********************


@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
@damyanp damyanp removed this from HLSL Support Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DirectX] Add API to find a resource binding given a use of a resource
7 participants