Skip to content

Conversation

@inbelic
Copy link
Contributor

@inbelic inbelic commented Dec 10, 2023

Add missing constant propogation folder for SNegate, [Logical]Not.

Implement additional folding when !(!x) for all ops.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704

Add missing constant propogation folder for SNegate, [Logical]Not. Implement additional folding when !(!x) for all ops. This helps for readability of lowered code into SPIR-V. Part of work for llvm#70704
@inbelic inbelic force-pushed the inbelic/spirv-folding-negate-ops branch from 174035d to e2e231d Compare December 20, 2023 08:16
@inbelic inbelic marked this pull request as ready for review December 20, 2023 17:03
@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Finn Plummer (inbelic)

Changes

Add missing constant propogation folder for SNegate, [Logical]Not.

Implement additional folding when !(!x) for all ops.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704


Full diff: https://github.com/llvm/llvm-project/pull/74992.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td (+1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+55)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+116)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 51124e141c6d46..22d5afcd773817 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -582,6 +582,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate", %3 = spirv.SNegate %2 : vector<4xi32> ``` }]; + + let hasFolder = 1; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td index b460c8e68aa0c6..38639a175ab4db 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -462,6 +462,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> { %3 = spirv.Not %1 : vector<4xi32> ``` }]; + + let hasFolder = 1; } #endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 47887ffb474f00..260d24b5502577 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -528,6 +528,7 @@ def SPIRV_LogicalNotOp : SPIRV_LogicalUnaryOp<"LogicalNot", }]; let hasCanonicalizer = 1; + let hasFolder = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 9de1707dfca465..fe334d50b6faaa 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -643,6 +643,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) { return div0 ? Attribute() : res; } +//===----------------------------------------------------------------------===// +// spirv.SNegate +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) { + // -(-x) = 0 - (0 - x) = x + auto op = getOperand(); + if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>()) + return negateOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Signed-integer subtract of Operand from zero. + return constFoldUnaryOp<IntegerAttr>( + adaptor.getOperands(), [](const APInt &a) { + APInt zero = APInt::getZero(a.getBitWidth()); + return zero - a; + }); +} + +//===----------------------------------------------------------------------===// +// spirv.NotOp +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) { + // !(!x) = x + auto op = getOperand(); + if (auto notOp = op.getDefiningOp<spirv::NotOp>()) + return notOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Complement the bits of Operand. + return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) { + a.flipAllBits(); + return a; + }); +} + //===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// @@ -681,6 +720,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) { // spirv.LogicalNot //===----------------------------------------------------------------------===// +OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) { + // !(!x) = x + auto op = getOperand(); + if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>()) + return notOp->getOperand(0); + + // According to the SPIR-V spec: + // + // Complement the bits of Operand. + return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), + [](const APInt &a) { + APInt zero = APInt::getZero(1); + return a == 1 ? zero : (zero + 1); + }); +} + void spirv::LogicalNotOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 29bea91ce461d9..7da2cf5be4e007 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -1006,6 +1006,90 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) { // ----- +//===----------------------------------------------------------------------===// +// spirv.SNegate +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @snegate_twice +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @snegate_twice(%arg0 : i32) -> i32 { + %0 = spirv.SNegate %arg0 : i32 + %1 = spirv.SNegate %0 : i32 + + // CHECK: return %[[ARG]] : i32 + return %1 : i32 +} + +// CHECK-LABEL: @const_fold_scalar_snegate +func.func @const_fold_scalar_snegate() -> (i32, i32, i32) { + %c0 = spirv.Constant 0 : i32 + %c3 = spirv.Constant 3 : i32 + %cn3 = spirv.Constant -3 : i32 + + // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32 + // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32 + // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32 + %0 = spirv.SNegate %c0 : i32 + %1 = spirv.SNegate %c3 : i32 + %2 = spirv.SNegate %cn3 : i32 + + // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]] + return %0, %1, %2 : i32, i32, i32 +} + +// CHECK-LABEL: @const_fold_vector_snegate +func.func @const_fold_vector_snegate() -> vector<3xi32> { + // CHECK: spirv.Constant dense<[0, 3, -3]> + %cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32> + %0 = spirv.SNegate %cv : vector<3xi32> + return %0 : vector<3xi32> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.Not +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @not_twice +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @not_twice(%arg0 : i32) -> i32 { + %0 = spirv.Not %arg0 : i32 + %1 = spirv.Not %0 : i32 + + // CHECK: return %[[ARG]] : i32 + return %1 : i32 +} + +// CHECK-LABEL: @const_fold_scalar_not +func.func @const_fold_scalar_not() -> (i32, i32, i32) { + %c0 = spirv.Constant 0 : i32 + %c3 = spirv.Constant 3 : i32 + %cn3 = spirv.Constant -3 : i32 + + // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32 + // CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32 + // CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32 + %0 = spirv.Not %c0 : i32 + %1 = spirv.Not %c3 : i32 + %2 = spirv.Not %cn3 : i32 + + // CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]] + return %0, %1, %2 : i32, i32, i32 +} + +// CHECK-LABEL: @const_fold_vector_not +func.func @const_fold_vector_not() -> vector<3xi32> { + %cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32> + + // CHECK: spirv.Constant dense<[0, 3, -3]> + %0 = spirv.Not %cv : vector<3xi32> + + return %0 : vector<3xi32> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// @@ -1040,6 +1124,38 @@ func.func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector< // spirv.LogicalNot //===----------------------------------------------------------------------===// +// CHECK-LABEL: @logical_not_twice +// CHECK-SAME: (%[[ARG:.*]]: i1) +func.func @logical_not_twice(%arg0 : i1) -> i1 { + %0 = spirv.LogicalNot %arg0 : i1 + %1 = spirv.LogicalNot %0 : i1 + + // CHECK: return %[[ARG]] : i1 + return %1 : i1 +} + +// CHECK-LABEL: @const_fold_scalar_logical_not +func.func @const_fold_scalar_logical_not() -> i1 { + %true = spirv.Constant true + + // CHECK: spirv.Constant false + %0 = spirv.LogicalNot %true : i1 + + return %0 : i1 +} + +// CHECK-LABEL: @const_fold_vector_logical_not +func.func @const_fold_vector_logical_not() -> vector<2xi1> { + %cv = spirv.Constant dense<[true, false]> : vector<2xi1> + + // CHECK: spirv.Constant dense<[false, true]> + %0 = spirv.LogicalNot %cv : vector<2xi1> + + return %0 : vector<2xi1> +} + +// ----- + func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { // CHECK: %[[RESULT:.*]] = spirv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> // CHECK-NEXT: spirv.ReturnValue %[[RESULT]] : vector<3xi1> 
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, I think it would be worth adding one more test for though.

- add testcase to demonstrate SNegate behaviour for INT_MIN
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@inbelic inbelic merged commit 88151dd into llvm:main Dec 21, 2023
@inbelic inbelic deleted the inbelic/spirv-folding-negate-ops branch August 2, 2024 19:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

3 participants