@fhahn
Contributor

@fhahn fhahn commented Sep 9, 2025

During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the ones common to both recipes.

During CSE, we don't have to drop poison-generating flags, if both the re-used recipe and the to-be-replaced recipe have the same flags.
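
For readers skimming the thread, here is a minimal sketch of the difference between dropping all flags and intersecting them. It uses a toy `WrapFlags` struct and two hypothetical helpers (`dropAll`, `intersect`); these are illustrative stand-ins, not the real `VPIRFlags` layout or API.

```cpp
#include <cassert>

// Toy stand-in for the nuw/nsw pair carried by an overflowing binary op.
// This is an assumption for illustration, not the actual VPIRFlags layout.
struct WrapFlags {
  bool HasNUW = false;
  bool HasNSW = false;
};

// Previous behaviour as described in the PR: the reused value loses every
// poison-generating flag, regardless of what the replaced value carried.
inline WrapFlags dropAll(WrapFlags) { return {}; }

// Behaviour proposed in the PR: a flag survives only if it is set on both
// the reused value and the value being replaced.
inline WrapFlags intersect(WrapFlags Kept, WrapFlags Replaced) {
  Kept.HasNUW &= Replaced.HasNUW;
  Kept.HasNSW &= Replaced.HasNSW;
  return Kept;
}
```

With `Kept = {nuw, nsw}` and `Replaced = {nsw}`, `intersect` leaves `nsw` in place, whereas `dropAll` would clear both flags.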
@llvmbot
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-powerpc

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the ones common to both recipes.


Full diff: https://github.com/llvm/llvm-project/pull/157664.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+36)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll (+1-1)
  • (modified) llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll (+4-4)
  • (modified) llvm/test/Transforms/LoopVectorize/flags.ll (+1-1)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b93bdf244237e..53291a931530f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -721,6 +721,10 @@ class VPIRFlags {
     AllFlags = Other.AllFlags;
   }
 
+  /// Only keep flags also present in \p Other. \p Other must have the same
+  /// OpType as the current object.
+  void intersectFlags(const VPIRFlags &Other);
+
   /// Drop all poison-generating flags.
   void dropPoisonGeneratingFlags() {
     // NOTE: This needs to be kept in-sync with
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46162a9276469..9f1311fbd0687 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
+void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
+  assert(OpType == Other.OpType && "OpType must match");
+  switch (OpType) {
+  case OperationType::OverflowingBinOp:
+    WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
+    WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
+    break;
+  case OperationType::Trunc:
+    TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
+    TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
+    break;
+  case OperationType::DisjointOp:
+    DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
+    break;
+  case OperationType::PossiblyExactOp:
+    ExactFlags.IsExact = Other.ExactFlags.IsExact;
+    break;
+  case OperationType::GEPOp:
+    GEPFlags &= Other.GEPFlags;
+    break;
+  case OperationType::FPMathOp:
+    FMFs.NoNaNs &= Other.FMFs.NoNaNs;
+    FMFs.NoInfs &= Other.FMFs.NoInfs;
+    break;
+  case OperationType::NonNegOp:
+    NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
+    break;
+  case OperationType::Cmp:
+    assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
+    break;
+  case OperationType::Other:
+    assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
+    break;
+  }
+}
+
 FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 10b2f5df2e23e..d86b53dd894fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
         // V must dominate Def for a valid replacement.
         if (!VPDT.dominates(V->getParent(), VPBB))
           continue;
-        // Drop poison-generating flags when reusing a value.
+        // Only keep flags present on both V and Def.
         if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
-          RFlags->dropPoisonGeneratingFlags();
+          RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
         Def->replaceAllUsesWith(V);
         continue;
       }
diff --git a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
index 36c3a2a612d82..db1f2c71e0f77 100644
--- a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
+++ b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
 ; CHECK: vector.body:
 ; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
 ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
 ; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
 ; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
diff --git a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
index df54411f7e710..c2dfce0aa70b8 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT: [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
 ; CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
 ; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
-; CHECK-NEXT: [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
-; CHECK-NEXT: [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
+; CHECK-NEXT: [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
+; CHECK-NEXT: [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
 ; CHECK-NEXT: [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
 ; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
 ; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT: [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
 ; CHECK-NEXT: [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
 ; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
-; CHECK-NEXT: [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
-; CHECK-NEXT: [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
+; CHECK-NEXT: [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
+; CHECK-NEXT: [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
 ; CHECK-NEXT: [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
 ; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
 ; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
diff --git a/llvm/test/Transforms/LoopVectorize/flags.ll b/llvm/test/Transforms/LoopVectorize/flags.ll
index cef8ea656afaa..cbdcd50476b98 100644
--- a/llvm/test/Transforms/LoopVectorize/flags.ll
+++ b/llvm/test/Transforms/LoopVectorize/flags.ll
@@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
 ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
 ; CHECK: [[VECTOR_BODY]]:
 ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
 ; CHECK-NEXT: store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
 ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
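
The `flags.ll` update (a shared `nusw` now survives on the GEP) can be pictured with a small bitmask model. The sketch below is an assumption-laden illustration: the `GEPBits` enum, its bit values, and `intersectGEP` are hypothetical and only mirror the LangRef rule that `inbounds` implies `nusw`; they are not the real LLVM flag representation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bit assignment for GEP no-wrap flags, chosen for illustration.
enum GEPBits : uint8_t {
  NUSW = 1 << 0,
  NUW = 1 << 1,
  InBoundsBit = 1 << 2,
  // LangRef: inbounds implies nusw, so model inbounds with both bits set.
  InBounds = InBoundsBit | NUSW,
};

// Intersecting two flag sets is a plain bitwise AND in this model.
constexpr uint8_t intersectGEP(uint8_t A, uint8_t B) { return A & B; }

// A GEP with `inbounds nuw` merged with one carrying only `nusw` keeps `nusw`.
static_assert(intersectGEP(InBounds | NUW, NUSW) == NUSW,
              "shared nusw must survive the intersection");
```

Encoding the implied flag as an extra set bit is what lets a plain AND give the right answer here; without it, `inbounds` intersected with `nusw` would wrongly come out empty.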
Contributor

@artagnon artagnon left a comment

LGTM modulo one error, thanks!

DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
break;
case OperationType::PossiblyExactOp:
ExactFlags.IsExact = Other.ExactFlags.IsExact;
Contributor

Suggested change
- ExactFlags.IsExact = Other.ExactFlags.IsExact;
+ ExactFlags.IsExact &= Other.ExactFlags.IsExact;
Contributor Author

Should be fixed, thanks. Also added a few missing tests.
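
To spell out why the `=` versus `&=` distinction in this thread matters: with plain assignment, a reused recipe that lacks `exact` would pick up `exact` from the recipe it replaces. The following sketch uses two hypothetical free functions over plain `bool` values (not the real `ExactFlags` member) to show the difference.

```cpp
#include <cassert>

// Mirrors the line as originally posted: the kept flag is overwritten with
// the other recipe's flag, so it can be *gained* as well as lost.
inline bool intersectExactBuggy(bool Mine, bool Other) {
  Mine = Other;
  return Mine;
}

// Mirrors the suggested change: the flag survives only if both sides have it.
inline bool intersectExactFixed(bool Mine, bool Other) {
  Mine &= Other;
  return Mine;
}
```

For `Mine = false, Other = true`, the first helper returns `true` and the second returns `false`; only the second is a true intersection.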

@fhahn fhahn enabled auto-merge (squash) September 10, 2025 09:49
@fhahn fhahn merged commit c3e76b2 into llvm:main Sep 10, 2025
9 checks passed
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 10, 2025
During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the ones common to both recipes. PR: llvm/llvm-project#157664
@fhahn fhahn deleted the vplan-cse-retain-matching-flags branch September 10, 2025 12:12
@jyknight
Member

This causes an assertion failure on the following (llvm-reduce'd) test-case, when run as opt -passes='loop-vectorize<no-interleave-forced-only;no-vectorize-forced-only;>'.

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-linux-gnu"

define ptr @test(ptr %0) {
  br label %loop

loop:
  %2 = phi i64 [ 0, %1 ], [ %12, %loop ]
  %3 = zext i1 false to i64
  %4 = load i8, ptr %0, align 1
  %5 = and i8 %4, 1
  %6 = trunc i8 %4 to i1
  %7 = select i1 %6, float 1.000000e+00, float 0.000000e+00
  %8 = trunc i8 %5 to i1
  %9 = select i1 %8, float 0.000000e+00, float %7
  %10 = bitcast float %9 to i32
  %11 = trunc i32 %10 to i8
  store i8 %11, ptr null, align 1
  %12 = add i64 %2, 1
  %exitcond.not = icmp eq i64 %2, 1
  br i1 %exitcond.not, label %exit, label %loop

exit:
  ret ptr null
}

With error:

opt: llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp:396: void llvm::VPIRFlags::intersectFlags(const VPIRFlags &): Assertion `OpType == Other.OpType && "OpType must match"' failed. 

(where OpType is "Other", and Other.OpType is "Trunc").
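
As an illustration of the failure mode only, the toy model below shows how an intersection that assumes matching flag kinds breaks when two values considered equal by CSE carry different kinds, and one conceivable defensive shape (report the mismatch instead of asserting). The `OpKind`, `Flags`, and `tryIntersect` names are hypothetical, and this is not a description of the actual fix, which is not shown in this thread.

```cpp
#include <cassert>

// Toy flag kinds; the report above involves "Other" versus "Trunc".
enum class OpKind { Trunc, Other };

struct Flags {
  OpKind Kind;
  bool HasNUW;
  bool HasNSW;
};

// Hypothetical guarded variant: returns false (leaving Kept untouched) when
// the kinds differ, so a caller could skip the replacement or fall back to
// a more conservative action instead of hitting an assertion.
inline bool tryIntersect(Flags &Kept, const Flags &Replaced) {
  if (Kept.Kind != Replaced.Kind)
    return false;
  if (Kept.Kind == OpKind::Trunc) {
    Kept.HasNUW &= Replaced.HasNUW;
    Kept.HasNSW &= Replaced.HasNSW;
  }
  return true;
}
```

Whether skipping the replacement, dropping all flags, or tightening the CSE key is the right response depends on the pass; the sketch only shows where the check would sit.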

@fhahn
Contributor Author

fhahn commented Sep 24, 2025

thanks for the heads up, taking a look

@fhahn
Contributor Author

fhahn commented Sep 25, 2025

@jyknight should be fixed as part of the fix for #160396

@jyknight
Member

@jyknight should be fixed as part of the fix for #160396

Thanks! Verified the fix on the original code.
