Conversation

nbpatel (Contributor) commented Nov 13, 2025

No description provided.

llvmbot (Member) commented Nov 13, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/167970.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+55-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+13)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+20)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..afab880d173c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
+/// subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newMaskOps;
+    for (int i = 0; i < count; ++i) {
+      Value newMaskOp;
+      if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+        newMaskOp = vector::CreateMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getOperands());
+      } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+        newMaskOp = vector::ConstantMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
+      } else {
+        return rewriter.notifyMatchFailure(op,
+                                           "Unsupported mask operation type");
+      }
+      xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
+                                     layout.dropSgLayoutAndData());
+
+      newMaskOps.push_back(newMaskOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newMaskOps});
+    return success();
+  }
+};
+
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+
 } // namespace
 
 namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
       WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+      WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+      WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
       patterns.getContext());
 }
 } // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                                vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..b587ecc726f4d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,18 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: %[[CST16:.*]] = arith.constant 16 : index
+    // CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    // CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
+    // CHECK-NOT: vector.constant_mask
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..f254b82c6401f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,24 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    %cst8 = arith.constant 8 : index
+    // CHECK: vector.create_mask {{.*}} : vector<16xi1>
+    %create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
+    // CHECK: vector.constant_mask [8] : vector<16xi1>
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    // CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
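For intuition, here is a worked sketch of the round-robin case exercised by xegpu-wg-to-sg-unify-ops-rr.mlir above (the arithmetic is inferred from how getSgShapeAndCount is used in the pass; it is not spelled out in the PR): with a 256x128 mask, sg_layout = [8, 4], and sg_data = [16, 16], one round of subgroups covers a 128x64 tile, so each subgroup owns (256/128) x (128/64) = 2 x 2 = 4 chunks of 16x16, which is why the test uses CHECK-COUNT-4.

```mlir
// Illustrative input (workgroup level); the constant names are hypothetical.
%c16 = arith.constant 16 : index
%m = vector.create_mask %c16, %c16
  {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>}
  : vector<256x128xi1>

// After distribution, each subgroup materializes its 4 round-robin chunks:
%m0 = vector.create_mask %c16, %c16 : vector<16x16xi1>
%m1 = vector.create_mask %c16, %c16 : vector<16x16xi1>
%m2 = vector.create_mask %c16, %c16 : vector<16x16xi1>
%m3 = vector.create_mask %c16, %c16 : vector<16x16xi1>
```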
nbpatel changed the title from "[XeGPU][WgToSg] Add distribution for vector mask operations" to "[MLIR][XeGPU][WgToSg] Add distribution for vector mask operations" on Nov 13, 2025
nbpatel (Contributor, Author) commented Nov 13, 2025

Actually, this will get complex if the mask dim sizes are bigger than sgData. Please hold off on reviewing this one until I update it.
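To make the concern above concrete, here is a hypothetical 1-D case (not from the PR) where the requested mask size exceeds one subgroup's sg_data, so cloning the op with its original operands, as the posted pattern does, would produce the wrong per-subgroup mask:

```mlir
// Hypothetical: 24 valid lanes in a 32-wide workgroup mask,
// distributed as sg_layout = [2], sg_data = [16].
%c24 = arith.constant 24 : index
%mask = vector.create_mask %c24
  {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>}
  : vector<32xi1>

// A correct distribution is offset-dependent:
//   subgroup 0 (lanes  0..15) needs a create_mask with bound 16
//   subgroup 1 (lanes 16..31) needs a create_mask with bound 8
// i.e. the per-subgroup bound is clamp(24 - 16 * sg_id, 0, 16); the pattern
// as posted instead reuses %c24 for every subgroup.
```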

nbpatel closed this on Nov 14, 2025