From c7ade67e715206eb20207a7fe84b9807e3f22065 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Mon, 5 May 2025 16:45:11 -0700 Subject: [PATCH 01/10] green commit tracker --- deps.bzl | 28 +++++++++++++++++++------ lib/Pipeline/Pipeline.cpp | 10 ++++----- test/Pipeline/tcp_to_llvm_pipeline.mlir | 1 - 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/deps.bzl b/deps.bzl index 8fcf9f5..4216d0d 100644 --- a/deps.bzl +++ b/deps.bzl @@ -22,8 +22,8 @@ def third_party_deps(): path = local_llvm_repo_path(), ) else: - LLVM_COMMIT = "5d6d982df61d16b6d498e6d59dd91c059679d3d8" - LLVM_SHA256 = "834184126812eecbdb2ed30de255554a6529295afaf44e9dfd3851d61195dbb5" + LLVM_COMMIT = "72144d119a7291f8b6b8e022a2947fbe31e66afc" + LLVM_SHA256 = "2caacb6925a13cb5886a5d7f225fa408b80ca8e1efe0736186954b2abc4ee1c3" http_archive( name = "llvm-raw", build_file_content = "# empty", @@ -39,8 +39,8 @@ def third_party_deps(): path = local_torch_mlir_repo_path(), ) else: - TORCH_MLIR_COMMIT = "169032010793ee7fe3e305ab920e4119fdfc3b11" - TORCH_MLIR_SHA256 = "0f25459b0d6828983c8aa78d139adad4325508bff150b57e97345e9798377dd3" + TORCH_MLIR_COMMIT = "9f2ba5abaa85cefd95cc85579fafd0c53c1101e8" + TORCH_MLIR_SHA256 = "09444281839eeae4aff42c029d87b1728f307fa26511b896ff448d51aaa98049" http_archive( name = "torch-mlir-raw", build_file_content = "# empty", @@ -55,8 +55,8 @@ def third_party_deps(): path = local_stablehlo_repo_path(), ) else: - STABLEHLO_COMMIT = "b62dc66da9946b4c400c0d99c9d5bb8e04edaee6" - STABLEHLO_SHA256 = "a51842f5cbcccc2dc74de232793e6fdc0b4403b616281a73bbc704cd227b50db" + STABLEHLO_COMMIT = "a54938f0651d3b4b7be9771848eda2463c92a8e7" + STABLEHLO_SHA256 = "edab2288f0b19e3efbf08815d17d4efb106984aa6fe02fed0cb2165284e6a5b7" http_archive( name = "stablehlo", sha256 = STABLEHLO_SHA256, @@ -168,3 +168,19 @@ def third_party_deps(): strip_prefix = "cnpy-4e8810b1a8637695171ed346ce68f6984e585ef4", urls = ["https://github.com/rogersce/cnpy/archive/4e8810b1a8637695171ed346ce68f6984e585ef4.tar.gz"], ) + + http_archive( + name = "nanobind", + build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD", + sha256 = "bb35deaed7efac5029ed1e33880a415638352f757d49207a8e6013fefb6c49a7", + strip_prefix = "nanobind-2.4.0", + url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.4.0.tar.gz", + ) + + http_archive( + name = "robin_map", + build_file = "@llvm-raw//utils/bazel/third_party_build:robin_map.BUILD", + sha256 = "a8424ad3b0affd4c57ed26f0f3d8a29604f0e1f2ef2089f497f614b1c94c7236", + strip_prefix = "robin-map-1.3.0", + url = "https://github.com/Tessil/robin-map/archive/refs/tags/v1.3.0.tar.gz", + ) diff --git a/lib/Pipeline/Pipeline.cpp b/lib/Pipeline/Pipeline.cpp index e044057..6ff1d68 100644 --- a/lib/Pipeline/Pipeline.cpp +++ b/lib/Pipeline/Pipeline.cpp @@ -80,10 +80,10 @@ static void createTcpToLlvmPipeline(OpPassManager &pm) { // One-shot bufferize tensor -> memref, from // https://mlir.llvm.org/docs/Bufferization/. - bufferization::OneShotBufferizationOptions bufferizationOptions; + bufferization::OneShotBufferizePassOptions bufferizationOptions; bufferizationOptions.bufferizeFunctionBoundaries = true; - bufferizationOptions.setFunctionBoundaryTypeConversion( - bufferization::LayoutMapOption::IdentityLayoutMap); + bufferizationOptions.functionBoundaryTypeConversion = + bufferization::LayoutMapOption::IdentityLayoutMap; pm.addPass(bufferization::createOneShotBufferizePass(bufferizationOptions)); // Buffer deallocation pipeline for automatically inserting // buffer deallocation ops after one-shot bufferization. @@ -95,14 +95,14 @@ static void createTcpToLlvmPipeline(OpPassManager &pm) { pm.addPass(bufferization::createLowerDeallocationsPass()); pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); - pm.addPass(createBufferizationToMemRefPass()); + pm.addPass(createConvertBufferizationToMemRefPass()); // Blanket-convert any remaining linalg ops to loops if any remain. pm.addNestedPass(createConvertLinalgToLoopsPass()); // Blanket-convert any remaining affine ops if any remain. pm.addPass(createLowerAffinePass()); // Convert SCF to CF (always needed). - pm.addPass(createConvertSCFToCFPass()); + pm.addPass(createSCFToControlFlowPass()); // Sprinkle some cleanups. pm.addPass(createCanonicalizerPass()); diff --git a/test/Pipeline/tcp_to_llvm_pipeline.mlir b/test/Pipeline/tcp_to_llvm_pipeline.mlir index b64a7ca..7a64efb 100644 --- a/test/Pipeline/tcp_to_llvm_pipeline.mlir +++ b/test/Pipeline/tcp_to_llvm_pipeline.mlir @@ -2,7 +2,6 @@ // CHECK-LABEL: llvm.func @main // CHECK: llvm.mlir.constant -// CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue // CHECK: llvm.extractvalue // CHECK: llvm.alloca From 64928d61d1e745e11a3e44641ce32ad972f371b7 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Mon, 5 May 2025 16:59:44 -0700 Subject: [PATCH 02/10] buildifier --- deps.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps.bzl b/deps.bzl index 4216d0d..cc03944 100644 --- a/deps.bzl +++ b/deps.bzl @@ -168,7 +168,7 @@ def third_party_deps(): strip_prefix = "cnpy-4e8810b1a8637695171ed346ce68f6984e585ef4", urls = ["https://github.com/rogersce/cnpy/archive/4e8810b1a8637695171ed346ce68f6984e585ef4.tar.gz"], ) - + http_archive( name = "nanobind", build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD", From dcb23deb7bb296a813b6c61fc02da362ce144c94 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 6 May 2025 06:16:00 -0700 Subject: [PATCH 03/10] update requirements_lock.txt because of CI failure --- requirements_lock.txt | 46 ++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/requirements_lock.txt b/requirements_lock.txt index fed32c6..e23abc0 100644 --- a/requirements_lock.txt +++ b/requirements_lock.txt @@ -157,28 +157,30 @@ sympy==1.13.3 \ --hash=sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73 \ --hash=sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9 # via torch -torch==2.7.0.dev20250221+cpu \ - --hash=sha256:03ed7e7e5186f2f0cadea798d1bbb249ca342618e90542b9634463c194ad8999 \ - --hash=sha256:08c6edef82a3a11afcc8ec86910eb41837ca7e646416280de6ef352f6e216370 \ - --hash=sha256:0fd00b3a1198f610fa706074243097d4448ee0a0efca96b809107c60af0a3f98 \ - --hash=sha256:17672aba174f465fe90bee6972ee9a542980b59f410772e68091a5d044604b9d \ - --hash=sha256:17db41853a494b1eb1b3c23e383297b2e4678d6629dfa539e8a9f71c4b05b32d \ - --hash=sha256:1b3c5bc3a52cdacee11794a5221aea89a6207439a24aedfd458d79b6b5d38ad1 \ - --hash=sha256:4201870a7d363dfb0c2015a26d847f71248685f352234c71834923f1b477b7ed \ - --hash=sha256:494d8fa9c469cdcb042f61cfbf2c505ebfd344e923af8ba7f6223cd8b2a7742c \ - --hash=sha256:5218559bd4c044977b3240aa0a2e188f7694865ef24611430bcb701ec10f5276 \ - --hash=sha256:73f82e45e5f1707100751bd0bedbb8bf242e32268913959dbe9f0e4ab3b3cb99 \ - --hash=sha256:9f80431e71a2e7d7795220eb9d549cf0490e6c5b36f0886539f7c9ef11b23a39 \ - --hash=sha256:a89ed8084b88720fed36655cd8fe49c1b5c135483c61fb9ac5231210a502ef9c \ - --hash=sha256:ac88604bf2dd8e4a53ed07877fb9e845a7c3c8a03aeeefe19f381655f19b056c \ - --hash=sha256:ae3698e5caa6ddf1ad40712924f45788e15e2a838793969b0477868e02953011 \ - --hash=sha256:aee87be29d490521806a414191e9e9afbc27e55167e79dbb66d06c18d87e0079 \ - --hash=sha256:b3a63e0b2e8c495d0781735720a063c8924014cfcae8a6f63a360a696b657730 \ - --hash=sha256:b8ac59ca5484c438a49b4d6a4ad3721256071c80687220ae8ffa355ea5c745ee \ - --hash=sha256:c61d02f308414e7a2b95972e2b950168af86d5b24dcedb332e29ecb8419b64cd \ - --hash=sha256:d9f4236bfb9b4dc39e7569ff5a2fb39f4b28b3f79dc2ff6b6ce70f7ca67fd40a \ - --hash=sha256:f15bc3e3e51227f18068da9cd5b8153734385d817723c153bb0cdea7285b1eae \ - --hash=sha256:f7433418e166b7a3e87e43f3e64115207a84dc50db8dc3b9a51428e737edcac7 +torch==2.8.0.dev20250506+cpu \ + --hash=sha256:02abfdcdbb9ca15e3c561d31b1617f9d88f978af49b3b76cc048a5159c4bbb19 \ + --hash=sha256:0304c11aa1a404a664a776dea4b61dab31707d5fecc1e165ea17b1c780049911 \ + --hash=sha256:081ecdc2ced1285b92cce4684922710af244ccf4e4430d36c746f025e6872a30 \ + --hash=sha256:0bdc6883695004803ea0e062382d21e432168d7ee93e6f77375d34fc43778ca8 \ + --hash=sha256:1c82f3cd449bee2adcfc8c1dc25b087fc3ed9eba239ea46449e1a087ddbf5f97 \ + --hash=sha256:370ae6fb1c8c132c4578973eb6066f14d10fb6cdc05a89e44660fec15bbce9a4 \ + --hash=sha256:3c68844186c4d43db95f096b120b91c530c4e92540eeeece90e59fd6ec078f03 \ + --hash=sha256:4017473f0a77cd2774a3c8245032fb9979ac08f92831f94f70d9e22612e2d5c1 \ + --hash=sha256:4575a76e5459285311d1f94fb8835fec81d5509321192716fcff8631aa258ae3 \ + --hash=sha256:48c682f8f369b573045d5922e989812b77183f4020a750b3339c3e64e42fd733 \ + --hash=sha256:4a64fd103df112e2dbfb00ab04ffef839bc1838caa40ff8bf86647eb39daa7ad \ + --hash=sha256:5f2a251b87dc7a359fe5b83772cb2830e01b0d75a585edc1ffe659a3e59ae17b \ + --hash=sha256:690f44ae8974588810a6c58052e908fb1abc7c3d34e335faccec0baba852596b \ + --hash=sha256:810c8106d575256c6e429e26a8edf58e4ab43fea0b10c4d56eed011f0712ee90 \ + --hash=sha256:8701a35246db0aa148ea3bb6edb022a639c16115912d2dc90cbad9a56c0ded2e \ + --hash=sha256:a5974f2958d12d01577e206417ee4d04dc2f2275505d266323cf23e828e46d96 \ + --hash=sha256:b17959e888c65cef0765bfef3e4813f3dad7d3d55f73c976ca33a47d2ff875b5 \ + --hash=sha256:b91059dce8f9c97fce586b1367c91c64912ba0866e2213510a3ffe522cee3aee \ + --hash=sha256:c8a7058db5c6c478d2f93a14f911dcc045d5470ed0920797ec5a6008a0bce354 \ + --hash=sha256:ce7960db4fb7899626a4a94c361b0d2c80c9b3bd6907b929380a176df27b9908 \ + --hash=sha256:e23ba269a7f189dc65c1b0ff937beb0630dfbe9a810cd307d284a51cbc8409d6 \ + --hash=sha256:e2cccdcc64938ede25afc43efaa4e70fdf45709c3f2b48549adc0d163aa7fadf \ + --hash=sha256:fb30a20142ed498569649208d67f03e9e9f345be79ab340ceec734439a475d9a # via -r requirements.txt torch-mlir==20250127.357 \ --hash=sha256:43c2362b6a5265405ac5d2291982d6b0d83afafc7ee37165f4cc6b845dec4c15 \ From b898a83cea9c0b073e7bd8328ac2303114037fc6 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 6 May 2025 06:16:09 -0700 Subject: [PATCH 04/10] clang-format --- lib/Dialect/IR/TcpOps.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index 7dbd2bb..fd01fbf 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -192,9 +192,9 @@ LogicalResult GatherOp::verify() { i == gatherDim)) { std::stringstream ss; ss << "indicies index " << i - << " expected to be less than or equal to input " << " (" - << indicesTensor.getShape()[i] << " <= " << inputTensor.getShape()[i] - << ")"; + << " expected to be less than or equal to input " + << " (" << indicesTensor.getShape()[i] + << " <= " << inputTensor.getShape()[i] << ")"; return emitOpError(ss.str()); } } @@ -293,7 +293,8 @@ ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser, void BindSymbolicShapeOp::print(OpAsmPrinter &p) { p << " " << getOperand() << ", ["; llvm::interleaveComma(getShapeSymbols(), p); - p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">"; + p << "], " + << "affine_map<" << getShapeExpressions().getValue() << ">"; p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape_expressions"}); p << " : " << getOperand().getType(); From b4d90378207aec75221a2a4a52a46c7dab2fae62 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Fri, 16 May 2025 08:22:13 -0700 Subject: [PATCH 05/10] patch torch-mlir to enable building stablehlo --- deps.bzl | 3 +++ third_party/patches/BUILD | 10 ++++++++++ third_party/patches/torch-mlir.1.patch | 12 ++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 third_party/patches/BUILD create mode 100644 third_party/patches/torch-mlir.1.patch diff --git a/deps.bzl b/deps.bzl index cc03944..960a560 100644 --- a/deps.bzl +++ b/deps.bzl @@ -47,6 +47,9 @@ def third_party_deps(): sha256 = TORCH_MLIR_SHA256, strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT, urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)], + patches = [ + "//third_party/patches:torch-mlir.1.patch", + ] ) if use_local_stablehlo_repo(): diff --git a/third_party/patches/BUILD b/third_party/patches/BUILD new file mode 100644 index 0000000..da6be20 --- /dev/null +++ b/third_party/patches/BUILD @@ -0,0 +1,10 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +filegroup( + name = "all_files", + srcs = glob(["*"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/patches/torch-mlir.1.patch b/third_party/patches/torch-mlir.1.patch new file mode 100644 index 0000000..0426372 --- /dev/null +++ b/third_party/patches/torch-mlir.1.patch @@ -0,0 +1,12 @@ +diff --git lib/InitAll.cpp lib/InitAll.cpp +index d9096929..2a9be6cc 100644 +--- lib/InitAll.cpp ++++ lib/InitAll.cpp +@@ -33,6 +33,7 @@ + #ifdef TORCH_MLIR_ENABLE_STABLEHLO + #include "stablehlo/conversions/linalg/transforms/Passes.h" + #include "stablehlo/transforms/Passes.h" ++#include "stablehlo/transforms/optimization/Passes.h" + #endif + + #ifdef TORCH_MLIR_ENABLE_TOSA From 7307e75ecdd044321c6fc9c8ba36abb16dd2490e Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Fri, 16 May 2025 11:55:48 -0700 Subject: [PATCH 06/10] debug --- .github/workflows/bazelBuildAndTestTcp.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/bazelBuildAndTestTcp.yml b/.github/workflows/bazelBuildAndTestTcp.yml index c608ff3..9cde609 100644 --- a/.github/workflows/bazelBuildAndTestTcp.yml +++ b/.github/workflows/bazelBuildAndTestTcp.yml @@ -56,6 +56,9 @@ jobs: find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i if [ -n "$(git status --porcelain)" ]; then echo "Please run 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes." + echo "git reports the following changes: " + echo "$(git status --porcelain)" + echo "clang-format version: $(clang-format --version)" exit 1 fi From 7bace8b7f7d0a031f0ae547329b6e591f4ea047d Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Fri, 16 May 2025 15:55:56 -0400 Subject: [PATCH 07/10] revert formatting changes --- lib/Dialect/IR/TcpOps.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp index fd01fbf..7dbd2bb 100644 --- a/lib/Dialect/IR/TcpOps.cpp +++ b/lib/Dialect/IR/TcpOps.cpp @@ -192,9 +192,9 @@ LogicalResult GatherOp::verify() { i == gatherDim)) { std::stringstream ss; ss << "indicies index " << i - << " expected to be less than or equal to input " - << " (" << indicesTensor.getShape()[i] - << " <= " << inputTensor.getShape()[i] << ")"; + << " expected to be less than or equal to input " << " (" + << indicesTensor.getShape()[i] << " <= " << inputTensor.getShape()[i] + << ")"; return emitOpError(ss.str()); } } @@ -293,8 +293,7 @@ ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser, void BindSymbolicShapeOp::print(OpAsmPrinter &p) { p << " " << getOperand() << ", ["; llvm::interleaveComma(getShapeSymbols(), p); - p << "], " - << "affine_map<" << getShapeExpressions().getValue() << ">"; + p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">"; p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape_expressions"}); p << " : " << getOperand().getType(); From ad09232b61abd761dbdfd6be7fa8de8d9e3ae43a Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Fri, 16 May 2025 12:56:33 -0700 Subject: [PATCH 08/10] more debugging --- .github/workflows/bazelBuildAndTestTcp.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bazelBuildAndTestTcp.yml b/.github/workflows/bazelBuildAndTestTcp.yml index 9cde609..b4f21cc 100644 --- a/.github/workflows/bazelBuildAndTestTcp.yml +++ b/.github/workflows/bazelBuildAndTestTcp.yml @@ -58,7 +58,9 @@ jobs: echo "Please run 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes." echo "git reports the following changes: " echo "$(git status --porcelain)" - echo "clang-format version: $(clang-format --version)" + echo "$(git diff -u)" + docker run --rm mlir-tcp:ci clang-format --version + docker run --rm mlir-tcp:ci uname -a exit 1 fi From c1fbe6864b0baf2d0fcb3097f9f1077f248a204f Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Mon, 19 May 2025 06:58:05 -0700 Subject: [PATCH 09/10] buildifier --- deps.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps.bzl b/deps.bzl index 960a560..e783d9a 100644 --- a/deps.bzl +++ b/deps.bzl @@ -49,7 +49,7 @@ def third_party_deps(): urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)], patches = [ "//third_party/patches:torch-mlir.1.patch", - ] + ], ) if use_local_stablehlo_repo(): From 7b0a118841457bce79c48c6367a3c7a706377a57 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Wed, 4 Jun 2025 12:52:15 -0700 Subject: [PATCH 10/10] lit test updates --- test/python_lit/fx_import/basic_test.py | 5 ++--- test/python_lit/fx_import/custom_op_test.py | 14 +++++++------- .../fx_import/symbolic_shape_expr_test.py | 16 ++++++++-------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/test/python_lit/fx_import/basic_test.py b/test/python_lit/fx_import/basic_test.py index 4ba5524..13338d3 100644 --- a/test/python_lit/fx_import/basic_test.py +++ b/test/python_lit/fx_import/basic_test.py @@ -24,10 +24,10 @@ def run(f): @run # CHECK-LABEL: test_import_frozen_exported_program # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> -# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[a:.+]] = torch.aten.rand{{.*}} -> !torch.vtensor<[1,4],f32> # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] # CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] # CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] # CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] @@ -35,7 +35,6 @@ def run(f): # # Validate dialect resources exist. # CHECK: dialect_resources: -# CHECK-DAG: torch_tensor_1_4_torch.float32 # CHECK-DAG: torch_tensor_3_1_torch.float32 def test_import_frozen_exported_program(): # Tests the basic structural premises of import_frozen_exported_program, diff --git a/test/python_lit/fx_import/custom_op_test.py b/test/python_lit/fx_import/custom_op_test.py index d4105c2..ada3034 100644 --- a/test/python_lit/fx_import/custom_op_test.py +++ b/test/python_lit/fx_import/custom_op_test.py @@ -26,15 +26,15 @@ def run(f): # CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, # CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, # CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { -# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int -# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int -# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int -# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int -# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[S0:.+]] = torch.symbolic_int "s{{.*}}" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s{{.*}}" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s{{.*}}" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s{{.*}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> -# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S3]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> -# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S1]], %[[S3]], %[[S0]], %[[S2]]], affine_map<()[s0, s1, s2, s3] -> (s2, s1 + s3 + s0 * 2, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32> def test_tanh_sigmoid_cat_custom_op(): diff --git a/test/python_lit/fx_import/symbolic_shape_expr_test.py b/test/python_lit/fx_import/symbolic_shape_expr_test.py index fd207a8..d17282f 100644 --- a/test/python_lit/fx_import/symbolic_shape_expr_test.py +++ b/test/python_lit/fx_import/symbolic_shape_expr_test.py @@ -26,20 +26,20 @@ def run(f): # CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, # CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, # CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { -# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int -# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int -# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int -# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int -# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[S0:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s{{[0-9]+}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> -# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S3]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> -# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> # CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list # CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list, !torch.int -> !torch.vtensor<[?,?,3],f32> -# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S1]], %[[S3]], %[[S0]], %[[S2]]], affine_map<()[s0, s1, s2, s3] -> (s2, s1 + s3 + s0 * 2, 3)> : !torch.vtensor<[?,?,3],f32> # CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32> def test_tanh_sigmoid_cat(): class TanhSigmoidCat(nn.Module):