From fca9a88d78f6945d58c1ef8d6965942679028082 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 19 Nov 2024 04:35:21 +0100 Subject: [PATCH] [mlir][IR] Treat `tf32` as float with bitwidth 19 --- mlir/lib/IR/BuiltinTypes.cpp | 5 ----- mlir/test/IR/attribute.mlir | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 25e9f80c9963c..e8e8f3cdfbfd7 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -91,11 +91,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - // The actual width of TF32 is 19 bits. However, since it is a truncated - // version of Float32, we treat it as 32 bits in MLIR FloatType::getWidth - // for compatibility. - if (llvm::isa(*this)) - return 32; return APFloat::semanticsSizeInBits(getFloatSemantics()); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index a62de3f5004d7..0085d64ae82b6 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -561,6 +561,14 @@ func.func @correct_type_pass() { // ----- +func.func @tf32_elements_attr() { + // CHECK: "foo"() {attr = dense<4.000000e+00> : tensor} : () -> () + "foo"() {attr = dense<4.0> : tensor} : () -> () + return +} + +// ----- + //===----------------------------------------------------------------------===// // Test StringElementsAttr //===----------------------------------------------------------------------===// @@ -675,6 +683,14 @@ func.func @dense_array_attr() attributes { // ----- +func.func @test_invalid_bitwidth_type() { + // expected-error @below{{element type bitwidth must be a multiple of 8}} + "foo"() {tf32attr = array} : () -> () + return +} + +// ----- + func.func @testConfinedDenseArrayAttr() { "test.confined_dense_array_attr"() { i64attr = array,