
Conversation

@akroviakov
Contributor

@akroviakov commented on Oct 2, 2025

This PR improves warp distribution robustness by:

  1. Ensuring that, during warp result deduplication, results with no uses are never mapped to a non-existent index. Currently we map such results to newResultTypes.size() but may then skip inserting a corresponding entry, leaving a dangling index that causes a later out-of-bounds (OOB) error. (A simplified model of this indexing logic is sketched below.)
  2. Simplifying the scf.if and scf.for handling by using moveRegionToNewWarpOpAndAppendReturns, which also deduplicates warp results in place. This avoids cases where, for example, after sinking two scf.if ops that need the same escaping value, a higher-ranked sink pattern tries to sink the producer of that escaping value (which is yielded twice at this point) before WarpOpDeadResult has deduplicated the results, so the same op ends up being sunk twice (once per yield operand).
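
To make point 1 concrete, here is a minimal standalone C++ sketch of the indexing logic (a toy model, not the actual MLIR pattern; the dedupResults helper and the string-encoded yield operands are made up for illustration). It mirrors the fixed behavior: unused results are skipped before touching the dedup map, so no result is ever mapped to an index that was never pushed.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model of WarpOpDeadResult's deduplication: yieldOperands[i] stands for
// the value yielded for result i, used[i] for whether result i has any uses.
// Returns, for each used result, the index of its (deduplicated) new result.
static std::map<size_t, size_t>
dedupResults(const std::vector<std::string> &yieldOperands,
             const std::vector<bool> &used,
             std::vector<std::string> &newYieldValues) {
  std::map<std::string, size_t> seen;  // yield operand -> new result index
  std::map<size_t, size_t> resultMap;  // old result index -> new result index
  for (size_t i = 0; i < yieldOperands.size(); ++i) {
    // The fix: skip unused results *before* touching the map. Mapping them to
    // newYieldValues.size() and then not pushing anything would leave a
    // dangling, out-of-bounds index behind (the reported OOB error).
    if (!used[i])
      continue;
    auto [it, inserted] = seen.emplace(yieldOperands[i], newYieldValues.size());
    resultMap[i] = it->second;
    if (!inserted)
      continue; // duplicate yield operand: reuse the earlier result
    newYieldValues.push_back(yieldOperands[i]);
  }
  return resultMap;
}

int main() {
  // Mirrors the added test: results 0 and 1 are unused, result 2 yields the
  // same value (%2) as result 0.
  std::vector<std::string> yields = {"%2", "%3", "%2"};
  std::vector<bool> used = {false, false, true};
  std::vector<std::string> newYields;
  auto map = dedupResults(yields, used, newYields);
  assert(newYields.size() == 1 && map.at(2) == 0); // every index is in bounds
  return 0;
}
```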

@llvmbot
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR improves warp distribution robustness by:

  1. Ensuring that, during warp result deduplication, results with no uses are never mapped to a non-existent index. Currently we map such results to newResultTypes.size() but may then skip inserting a corresponding entry, leaving a dangling index that causes a later out-of-bounds (OOB) error.
  2. Simplifying the scf.if and scf.for handling by using moveRegionToNewWarpOpAndAppendReturns, which also deduplicates warp results in place. This avoids cases where, for example, after sinking two scf.if ops that need the same escaping value, a higher-ranked sink pattern tries to sink the producer of that escaping value (which is yielded twice at this point) before WarpOpDeadResult has deduplicated the results, so the same op ends up being sunk twice (once per yield operand).

Full diff: https://github.com/llvm/llvm-project/pull/161647.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+20-39)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+19)
```diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..47aa1ca40fb03 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -934,11 +934,13 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     // 3. skipping from the new result types / new yielded values any result
     // that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
+      if (result.use_empty())
+        continue;
       Value yieldOperand = yield.getOperand(result.getResultNumber());
       auto it = dedupYieldOperandPositionMap.insert(
           std::make_pair(yieldOperand, newResultTypes.size()));
       dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
-      if (result.use_empty() || !it.second)
+      if (!it.second)
         continue;
       newResultTypes.push_back(result.getType());
       newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1845,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                               escapingValueDistTypesElse.end());
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
     for (auto [idx, val] :
          llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
-      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(val);
       newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
     }
-    // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // Replace the old `WarpOp` with the new one that has additional yield
+    // values and types.
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
     // `ifOp` returns the result of the inner warp op.
     SmallVector<Type> newIfOpDistResTypes;
     for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1872,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newIfOp = scf::IfOp::create(
-        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
-        static_cast<bool>(ifOp.thenBlock()),
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
         static_cast<bool>(ifOp.elseBlock()));
     auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1890,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
       for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) {
         innerWarpInputVals.push_back(
-            newWarpOp.getResult(warpResRangeStart));
+            newWarpOp.getResult(newIndices[warpResRangeStart]));
         escapeValToBlockArgIndex[escapingValues[i]] =
             innerWarpInputTypes.size();
         innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1938,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new
     // `IfOp` result.
     for (auto [origIdx, newIdx] : ifResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `IfOp`.
-    for (auto [origIdx, newIdx] : origToNewYieldIdx)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(ifOp);
-    rewriter.eraseOp(warpOp);
     return success();
   }
@@ -2065,19 +2058,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         escapingValueDistTypes.begin(), escapingValueDistTypes.end());
     // Next, we insert all non-`ForOp` yielded values and their distributed
-    // types. We also create a mapping between the non-`ForOp` yielded value
-    // index and the corresponding new `WarpOp` yield value index (needed to
-    // update users later).
-    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    // types.
     for (auto [i, v] :
          llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
-      nonForResultMapping[i] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
     }
     // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
@@ -2086,7 +2076,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
     for (size_t i = 0; i < escapingValuesStartIdx; ++i)
-      newForOpOperands.push_back(newWarpOp.getResult(i));
+      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2100,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
-      innerWarpInput.push_back(newWarpOp.getResult(i));
+      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
           innerWarpInputType.size();
       innerWarpInputType.push_back(
@@ -2146,20 +2136,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     if (!innerWarp.getResults().empty())
       scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
-    // Update the users of original `WarpOp` results that were coming from the
+    // Update the users of the new `WarpOp` results that were coming from the
     // original `ForOp` to the corresponding new `ForOp` result.
     for (auto [origIdx, newIdx] : forResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `ForOp`.
-    for (auto [origIdx, newIdx] : nonForResultMapping)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(forOp);
-    rewriter.eraseOp(warpOp);
     // Update any users of escaping values that were forwarded to the
     // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..401cdd29b281c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) {
 // CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
 // CHECK-PROP: return
 // CHECK-PROP: }
+
+// -----
+func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
+  %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
+    (vector<1xf32>, vector<2xf32>, vector<1xf32>) {
+    %2 = "some_def"() : () -> (vector<32xf32>)
+    %3 = "some_def"() : () -> (vector<64xf32>)
+    gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
+  }
+  %r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
+  return %r0 : vector<1xf32>
+}
+
+// CHECK-PROP: func @dedup_unused_result
+// CHECK-PROP:   %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
+// CHECK-PROP:   %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP:   %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:   gpu.yield %[[Y0]] : vector<32xf32>
+// CHECK-PROP:   "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
```
@akroviakov changed the title from "MLIR][Vector] Improve warp distribution robustness" to "[MLIR][Vector] Improve warp distribution robustness" on Oct 6, 2025
@akroviakov
Contributor Author

@charithaintc pinging

@charithaintc
Contributor

@charithaintc pinging

Sorry, I will take a look today.

Contributor

@charithaintc left a comment


This avoids cases where, for example, after sinking two scf.if ops that need the same escaping value, a higher-ranked sink pattern tries to sink the producer of that escaping value (which is yielded twice at this point) before WarpOpDeadResult has deduplicated the results, so the same op ends up being sunk twice (once per yield operand).

In this case, why can't the higher-ranked pattern check whether its result is duplicated in the yield list? If so, it would not sink and would let WarpOpDeadResult deal with it first.
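
For what it's worth, such a guard could look roughly like the following standalone sketch (shouldSkipSinking is a hypothetical helper, not code from this PR); the idea is simply to count how often the candidate value appears among the yield operands and defer sinking while it appears more than once.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical guard for a sink pattern: yieldOperands models the warp op's
// yield, candidate the value the pattern is about to distribute. If the value
// is yielded more than once, the pattern bails out and lets WarpOpDeadResult
// deduplicate the results first; the pattern can match again afterwards.
static bool shouldSkipSinking(const std::vector<std::string> &yieldOperands,
                              const std::string &candidate) {
  return std::count(yieldOperands.begin(), yieldOperands.end(), candidate) > 1;
}

int main() {
  // After sinking two scf.if that share an escaping value %v, %v is yielded
  // twice, so its producer should not be sunk yet.
  std::vector<std::string> yields = {"%v", "%a", "%v"};
  assert(shouldSkipSinking(yields, "%v"));
  assert(!shouldSkipSinking(yields, "%a"));
  return 0;
}
```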

    // values and types.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
Contributor


I think the existing code is still correct. All the other patterns just append the operands, so scf.for and scf.if should do the same.

I would try to handle this issue in the higher-priority pattern whose op gets duplicated. We could avoid sinking the op while it is yielded multiple times and wait for the deduplication to kick in.

Contributor Author


why can't the higher-ranked pattern check whether its result is duplicated in the yield list

This would actually make a good rule for all patterns in general. As for this particular PR, there is no reason to knowingly allow duplicated values when they can be avoided using existing utilities.

All the other patterns just append the operands, so scf.for and scf.if should do the same

All other distribution patterns use moveRegionToNewWarpOpAndAppendReturns (which performs deduplication), not moveRegionToNewWarpOpAndReplaceReturns, which simply sets the result types. So in this PR, we actually align with other patterns.
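
As a rough standalone model of that difference (appendAndDedup below is a made-up stand-in for moveRegionToNewWarpOpAndAppendReturns, assuming it deduplicates appended values against the existing yields as described above): a value that is already yielded is not appended again, and the returned newIndices tell the caller which result each appended value maps to, which is why the updated patterns index results through newIndices.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy model of "append returns with deduplication": values that are already
// yielded are not appended again; newIndices[i] tells the caller which result
// of the new warp op corresponds to appended value i.
static std::vector<size_t>
appendAndDedup(std::vector<std::string> &yieldOperands,
               const std::vector<std::string> &newValues) {
  std::vector<size_t> newIndices;
  for (const std::string &v : newValues) {
    auto it = std::find(yieldOperands.begin(), yieldOperands.end(), v);
    if (it != yieldOperands.end()) {
      newIndices.push_back(it - yieldOperands.begin());
      continue; // already yielded: reuse the existing result
    }
    newIndices.push_back(yieldOperands.size());
    yieldOperands.push_back(v);
  }
  return newIndices;
}

int main() {
  std::vector<std::string> yields = {"%cond"};
  // Two escaping values, one of which (%cond) is already yielded.
  auto newIndices = appendAndDedup(yields, {"%cond", "%escaping"});
  assert(yields.size() == 2);  // only %escaping was actually appended
  assert(newIndices[0] == 0);  // %cond maps to the existing result 0
  assert(newIndices[1] == 1);  // %escaping maps to the new result 1
  // A caller must therefore use newWarpOp.getResult(newIndices[i]) rather
  // than assuming its values occupy a contiguous trailing range of results.
  return 0;
}
```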

Contributor


So in this PR, we actually align with other patterns.

Got it. I missed it.

@akroviakov force-pushed the akroviak/robust-vector-dist branch from 04ee917 to 7875f48 on October 14, 2025
@charithaintc self-requested a review on October 14, 2025
Contributor

@charithaintc left a comment


LGTM.


@akroviakov force-pushed the akroviak/robust-vector-dist branch from 7875f48 to e6355ab on October 15, 2025
@akroviakov merged commit 0a71fd1 into llvm:main on Oct 15, 2025
10 checks passed