diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 76e837404c6b5..3117721a94152 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -34,24 +34,25 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result, LogicalResult FromTensorOp::verify() { ArrayRef tensorShape = getInput().getType().getShape(); RingAttr ring = getOutput().getType().getRing(); - unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree(); - bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree; - if (!compatible) { - InFlightDiagnostic diag = emitOpError() - << "input type " << getInput().getType() - << " does not match output type " - << getOutput().getType(); - diag.attachNote() << "the input type must be a tensor of shape [d] where d " - "is at most the degree of the polynomialModulus of " - "the output type's ring attribute"; - return diag; + IntPolynomialAttr polyMod = ring.getPolynomialModulus(); + if (polyMod) { + unsigned polyDegree = polyMod.getPolynomial().getDegree(); + bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree; + if (!compatible) { + InFlightDiagnostic diag = emitOpError() + << "input type " << getInput().getType() + << " does not match output type " + << getOutput().getType(); + diag.attachNote() + << "the input type must be a tensor of shape [d] where d " + "is at most the degree of the polynomialModulus of " + "the output type's ring attribute"; + return diag; + } } - APInt coefficientModulus = ring.getCoefficientModulus().getValue(); - unsigned cmodBitWidth = coefficientModulus.ceilLogBase2(); unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth(); - - if (inputBitWidth > cmodBitWidth) { + if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) { InFlightDiagnostic diag = emitOpError() << "input tensor element type " << getInput().getType().getElementType() @@ -67,24 +68,27 @@ LogicalResult FromTensorOp::verify() { LogicalResult ToTensorOp::verify() { ArrayRef tensorShape = getOutput().getType().getShape(); - unsigned polyDegree = getInput() - .getType() - .getRing() - .getPolynomialModulus() - .getPolynomial() - .getDegree(); - bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; + IntPolynomialAttr polyMod = + getInput().getType().getRing().getPolynomialModulus(); + if (polyMod) { + unsigned polyDegree = polyMod.getPolynomial().getDegree(); + bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; - if (compatible) - return success(); + if (compatible) + return success(); + + InFlightDiagnostic diag = emitOpError() + << "input type " << getInput().getType() + << " does not match output type " + << getOutput().getType(); + diag.attachNote() + << "the output type must be a tensor of shape [d] where d " + "is at most the degree of the polynomialModulus of " + "the input type's ring attribute"; + return diag; + } - InFlightDiagnostic diag = - emitOpError() << "input type " << getInput().getType() - << " does not match output type " << getOutput().getType(); - diag.attachNote() << "the output type must be a tensor of shape [d] where d " - "is at most the degree of the polynomialModulus of " - "the input type's ring attribute"; - return diag; + return success(); } LogicalResult MulScalarOp::verify() { diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir index f22b14897e98a..4937e17027afa 100644 --- a/mlir/test/Dialect/Polynomial/ops_errors.mlir +++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt --split-input-file --verify-diagnostics %s #my_poly = #polynomial.int_polynomial<1 + x**1024> -#ring = #polynomial.ring +#ring = #polynomial.ring !ty = !polynomial.polynomial func.func @test_from_tensor_too_large_coeffs() {