
Conversation

@nbpatel (Contributor) commented Nov 25, 2025

No description provided.

@llvmbot (Member) commented Nov 25, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/169571.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+29-20)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+8)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+37)
```diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index beb9b60aa9d7a..95c20b1fabe58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp
   }
 };
 
-// This pattern distributes the vector.constant_mask ops to work at subgroup
-// level.
-struct WgToSgVectorConstantMaskOp
-    : public OpConversionPattern<vector::ConstantMaskOp> {
-  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
@@ -1288,9 +1288,16 @@ struct WgToSgVectorConstantMaskOp
     VectorType type = op.getResult().getType();
     auto wgShape = type.getShape();
-    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+    SmallVector<Value> wgMaskDimSizes;
+    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+      for (int64_t maskSize : op.getMaskDimSizes()) {
+        wgMaskDimSizes.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+      }
+    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+      wgMaskDimSizes = llvm::to_vector(op.getOperands());
+    }
 
-    // Get subgroup ID.
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
 
     auto sgOffsets =
@@ -1302,19 +1309,17 @@ struct WgToSgVectorConstantMaskOp
     VectorType resultType = VectorType::get(sgShape, type.getElementType());
 
     // In each dimension, each subgroup computes its local mask size as:
-    // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+    // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
     SmallVector<Value> newCreateMaskOps;
     for (auto offsetSet : *sgOffsets) {
       SmallVector<Value> maskOperands;
 
-      for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
-        Value wgMaskSizeVal =
-            arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
+      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
         Value dimSizeVal =
             arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
         Value offset = offsetSet[i];
         Value adjustedMaskSize =
-            arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
+            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
         Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
         Value nonNegative =
             arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1335,6 +1340,8 @@ struct WgToSgVectorConstantMaskOp
   }
 };
 
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
 } // namespace
 
 namespace mlir {
@@ -1350,7 +1357,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
       WgToSgVectorStepOp, WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
       WgToSgVectorTransposeOp,
-      WgToSgVectorConstantMaskOp>(patterns.getContext());
+      WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1477,9 +1485,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<
-      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
-      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..4fb50b3b28534 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -138,5 +138,13 @@ gpu.module @test_distribution {
     %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     gpu.return
   }
+
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..48e93320093fd 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -583,6 +583,43 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: vector_create_mask_1D
+  gpu.func @vector_create_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+    %cst8 = arith.constant 8 : index
+    %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_create_mask_2D
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_slice_attr
   gpu.func @distribute_load_slice_attr() {
     %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
```
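For reference, the rule both mask patterns emit is, per dimension, `min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])`. The sketch below is not part of the patch; it is a minimal host-side illustration, assuming the 2D test case above (`sg_layout = [8, 4]`, `sg_data = [32, 32]`, workgroup mask `[16, 16]` over `vector<256x128xi1>`) and mirroring the `index.divu`/`index.remu` subgroup-id decomposition checked in the test. All variable names are illustrative.

```c++
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Shapes taken from the 2D test case: sg_layout = [8, 4],
  // sg_data = [32, 32], workgroup mask [16, 16] on a vector<256x128xi1>.
  const int64_t sgLayout[2] = {8, 4};
  const int64_t sgData[2] = {32, 32};
  const int64_t wgShape[2] = {256, 128};
  const int64_t wgMaskDimSizes[2] = {16, 16};

  for (int64_t sgId = 0; sgId < sgLayout[0] * sgLayout[1]; ++sgId) {
    // Linear subgroup id -> (row, col) subgroup coordinates, mirroring the
    // index.divu / index.remu chain in the CHECK lines.
    const int64_t sgIdY = (sgId / sgLayout[1]) % sgLayout[0];
    const int64_t sgIdX = sgId % sgLayout[1];
    // Element offset of this subgroup's tile, wrapped into the wg shape.
    const int64_t offset[2] = {(sgIdY * sgData[0]) % wgShape[0],
                               (sgIdX * sgData[1]) % wgShape[1]};
    // Per-dimension local mask size:
    //   min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
    int64_t localMask[2];
    for (int d = 0; d < 2; ++d)
      localMask[d] = std::min(
          std::max(wgMaskDimSizes[d] - offset[d], int64_t(0)), sgData[d]);
    std::cout << "sg " << sgId << " -> offsets (" << offset[0] << ", "
              << offset[1] << "), local mask [" << localMask[0] << ", "
              << localMask[1] << "]\n";
  }
  return 0;
}
```

With these values only subgroup (0, 0) gets a non-empty local mask of [16, 16]; every other subgroup ends up with a zero in at least one dimension and therefore creates an all-false mask, which is the expected outcome of distributing a workgroup-level mask that covers only the top-left 16x16 tile.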