diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 3ef899d3376b1..f99cbccd243ec 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> { // add two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> { // subtract two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> { // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [ // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> %1 = arith.constant 3 : i32 %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32 ``` @@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32) ``` }]; @@ -272,29 +272,29 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> { let hasVerifier = 1; } -def Polynomial_AnyPolynomialAttr : AnyAttrOf<[ - Polynomial_FloatPolynomialAttr, - Polynomial_IntPolynomialAttr +def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[ + Polynomial_TypedFloatPolynomialAttr, + Polynomial_TypedIntPolynomialAttr ]>; // Not deriving from Polynomial_Op due to need for custom assembly format -def Polynomial_ConstantOp : Op { +def Polynomial_ConstantOp : Op { let summary = "Define a constant polynomial via an attribute."; let description = [{ Example: ```mlir - #poly = #polynomial.int_polynomial - #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + !int_poly_ty = !polynomial.polynomial> + %0 = polynomial.constant int<1 + x**2> : !int_poly_ty - #float_ring = #polynomial.ring - %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> + !float_poly_ty = !polynomial.polynomial> + %1 = polynomial.constant float<0.5 + 1.3e06 x**2> : !float_poly_ty ``` }]; - let arguments = (ins Polynomial_AnyPolynomialAttr:$value); + let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value); let results = (outs Polynomial_PolynomialType:$output); - let assemblyFormat = "attr-dict `:` type($output)"; + let hasCustomAssemblyFormat = 1; } def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index e5dbfa7fa21ee..655020adf808b 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -18,7 +18,7 @@ class Polynomial_Attr traits = []> } def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with integer coefficients."; + let summary = "an attribute containing a single-variable polynomial with integer coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with integer coefficients, which is used to define the modulus of a `RingAttr`, as well @@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom } def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients."; + let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with double precision floating point coefficients. @@ -62,8 +62,72 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p let hasCustomAssemblyFormat = 1; } +def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< + "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> { + let summary = "a typed int_polynomial"; + let description = [{ + Example: + + ```mlir + !poly_ty = !polynomial.polynomial> + #poly = int<1 x**7 + 4> : !poly_ty + #poly_verbose = #polynomial.typed_int_polynomial<1 x**7 + 4> : !poly_ty + ``` + }]; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value); + let assemblyFormat = "$value `:` $type"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const IntPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + IntPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = ::mlir::Attribute; + }]; +} + +def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< + "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> { + let summary = "a typed float_polynomial"; + let description = [{ + Example: + + ```mlir + !poly_ty = !polynomial.polynomial> + #poly = float<1.4 x**7 + 4.5> : !poly_ty + #poly_verbose = #polynomial.typed_float_polynomial<1.4 x**7 + 4.5> : !poly_ty + ``` + }]; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value); + let assemblyFormat = "$value `:` $type"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const FloatPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + FloatPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = ::mlir::Attribute; + }]; +} + def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { - let summary = "An attribute specifying a polynomial ring."; + let summary = "an attribute specifying a polynomial ring"; let description = [{ A ring describes the domain in which polynomial arithmetic occurs. The ring attribute in `polynomial` represents the more specific case of polynomials diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index 890ce5226c30f..cc7d3172b1a1d 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -101,7 +101,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, return success(); } -template +template LogicalResult parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, llvm::StringSet<> &variables, @@ -155,7 +155,7 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { llvm::SmallVector monomials; llvm::StringSet<> variables; - if (failed(parsePolynomialAttr( + if (failed(parsePolynomialAttr( parser, monomials, variables, [&](IntMonomial &monomial) -> OptionalParseResult { APInt parsedCoeff(apintBitWidth, 1); @@ -175,7 +175,6 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { } return IntPolynomialAttr::get(parser.getContext(), result.value()); } - Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; @@ -191,8 +190,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { return OptionalParseResult(result); }; - if (failed(parsePolynomialAttr( - parser, monomials, variables, parseAndStoreCoefficient))) { + if (failed(parsePolynomialAttr(parser, monomials, variables, + parseAndStoreCoefficient))) { return {}; } diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 1a2439fe810b5..d0a25fd9288b9 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -186,6 +186,88 @@ LogicalResult INTTOp::verify() { return verifyNTTOp(this->getOperation(), ring, tensorType); } +ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { + // Using the built-in parser.parseAttribute requires the full + // #polynomial.typed_int_polynomial syntax, which is excessive. + // Instead we parse a keyword int to signal it's an integer polynomial + Type type; + if (succeeded(parser.parseOptionalKeyword("float"))) { + Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr); + if (floatPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + result.addAttribute("value", + TypedFloatPolynomialAttr::get(type, floatPolyAttr)); + result.addTypes(type); + return success(); + } + } + + if (succeeded(parser.parseOptionalKeyword("int"))) { + Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr); + if (intPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + + result.addAttribute("value", + TypedIntPolynomialAttr::get(type, intPolyAttr)); + result.addTypes(type); + return success(); + } + } + + // In the worst case, still accept the verbose versions. + TypedIntPolynomialAttr typedIntPolyAttr; + OptionalParseResult res = + parser.parseOptionalAttribute( + typedIntPolyAttr, "value", result.attributes); + if (res.has_value() && succeeded(res.value())) { + result.addTypes(typedIntPolyAttr.getType()); + return success(); + } + + TypedFloatPolynomialAttr typedFloatPolyAttr; + res = parser.parseAttribute( + typedFloatPolyAttr, "value", result.attributes); + if (res.has_value() && succeeded(res.value())) { + result.addTypes(typedFloatPolyAttr.getType()); + return success(); + } + + return failure(); +} + +void ConstantOp::print(OpAsmPrinter &p) { + p << " "; + if (auto intPoly = dyn_cast(getValue())) { + p << "int"; + intPoly.getValue().print(p); + } else if (auto floatPoly = dyn_cast(getValue())) { + p << "float"; + floatPoly.getValue().print(p); + } else { + assert(false && "unexpected attribute type"); + } + p << " : "; + p.printType(getOutput().getType()); +} + +LogicalResult ConstantOp::inferReturnTypes( + MLIRContext *context, std::optional location, + ConstantOp::Adaptor adaptor, + llvm::SmallVectorImpl &inferredReturnTypes) { + Attribute operand = adaptor.getValue(); + if (auto intPoly = dyn_cast(operand)) { + inferredReturnTypes.push_back(intPoly.getType()); + } else if (auto floatPoly = dyn_cast(operand)) { + inferredReturnTypes.push_back(floatPoly.getType()); + } else { + assert(false && "unexpected attribute type"); + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd canonicalization patterns //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index ff709960c50e9..4716e37ff8852 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -74,15 +74,19 @@ module { func.func @test_monic_monomial_mul() { %five = arith.constant 5 : index - %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial, index) -> !polynomial.polynomial return } func.func @test_constant() { - %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial - %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial - %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial + %1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial + %2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial + + // Test verbose fallbacks + %verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial + %verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial return }