[MLIR][XeGPU] Wrap layout with a slice attr when propagating broadcast #169054
base: main
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

The input vector of a broadcast operation has a lower rank than the broadcast result. In xegpu terms, this means that the input data is sliced (in the unit dimension). Currently, the broadcast simply passes the result layout to the operand, which is incorrect. This PR wraps the result layout in a slice attribute.

For the shape cast changes, I assume the shape-cast handling in the propagation code implicitly considers only `Nx1` or `1xN` kinds of shape cast results and does not need further slicing.

Full diff: https://github.com/llvm/llvm-project/pull/169054.diff

2 Files Affected:
```diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3a780abd3f12..2d8b5150d96fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -572,8 +572,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
         "one broadcasted dimension.");
     return;
   }
+  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+      broadcast->getContext(),
+      cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+      DenseI64ArrayAttr::get(broadcast->getContext(), {broadcastUnitDims[0]}));
   // Propagate the result layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
 }
 
 void LayoutInfoPropagation::visitShapeCastOp(
@@ -593,10 +597,19 @@ void LayoutInfoPropagation::visitShapeCastOp(
     return;
   }
   int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
-      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+  LayoutInfo operandLayout;
+  if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
+    auto sliceDims = sliceResultAttr.getDims().asArrayRef();
+    if (sliceDims.size() == 1 && sliceDims[0] == slicedDim)
+      operandLayout = resultLayout;
+  } else {
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        shapeCast->getContext(),
+        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+        DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
+    operandLayout = LayoutInfo(sliceLayout);
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index eb004932af4be..58ccb90f0bdb1 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -603,7 +603,7 @@ gpu.module @test {
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} :
 // CHECK-SAME: vector<16xf16> to vector<1x16xf16>
 func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -626,7 +626,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
 // CHECK-SAME: vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} :
 // CHECK-SAME: vector<16xf16> to vector<16x1xf16>
 func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -639,3 +639,28 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_slice_operand(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: i64) {
+// CHECK: %[[CST_0_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<1xindex>
+// CHECK: %[[CST_TRUE_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<1xi1>
+// CHECK: %[[CST_TRUE_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<32xi1>
+// CHECK: %[[CST_0_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<32xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%[[CST_0_1]]], %[[CST_TRUE_1]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>} :
+// CHECK-SAME: i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+// CHECK: %[[BCASTED:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<1xf32> to vector<32xf32>
+// CHECK: xegpu.store %[[BCASTED]], %[[ARG0]][%[[CST_0_32]]], %[[CST_TRUE_32]] : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+func.func @vector_broadcast_slice_operand(%arg0: i64) {
+  %offsets = arith.constant dense<0> : vector<1xindex>
+  %cst_4 = arith.constant dense<1> : vector<1xi1>
+  %cst_2 = arith.constant dense<1> : vector<32xi1>
+  %offsets_1 = arith.constant dense<0> : vector<32xindex>
+  %1 = xegpu.load %arg0[%offsets], %cst_4 : i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+  %2 = vector.broadcast %1 : vector<1xf32> to vector<32xf32>
+  xegpu.store %2, %arg0[%offsets_1], %cst_2 : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+  return
+}
+}
```
        shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
        DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
    LayoutInfo operandLayout;
Please add a short comment here.
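For illustration only, one possible wording of such a comment, reusing the names from the patch:

```cpp
// If the shape_cast result already carries a slice layout over the expected
// dim, reuse it for the 1-D operand; otherwise wrap the plain result layout
// in a SliceAttr so the operand layout reflects the sliced dimension.
LayoutInfo operandLayout;
```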
    if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
      auto sliceDims = sliceResultAttr.getDims().asArrayRef();
      if (sliceDims.size() == 1 && sliceDims[0] == slicedDim)
        operandLayout = resultLayout;
Maybe assert on the else condition?
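A minimal sketch of how that assert could look, reusing names from the patch (the message text is my own):

```cpp
if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
  // If the result already has a slice layout, it should slice exactly the
  // dimension this shape_cast introduces; anything else indicates a bug.
  auto sliceDims = sliceResultAttr.getDims().asArrayRef();
  assert(sliceDims.size() == 1 && sliceDims[0] == slicedDim &&
         "unexpected slice dims on shape_cast result layout");
  operandLayout = resultLayout;
} else {
  // Unchanged from the patch: wrap the plain layout in a SliceAttr.
}
```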
charithaintc left a comment
I think we need an additional canonicalization step to make broadcast nd -> nd. I am not sure the change in this PR is necessary. @Jianhui-Li, what do you think?
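One way to read that suggestion (a sketch under my own assumptions, not an existing pattern): canonicalize rank-expanding broadcasts into a unit-dim shape_cast followed by a rank-preserving broadcast, so layout propagation only ever sees nd -> nd broadcasts. The pattern name is hypothetical.

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical canonicalization: rewrite a vector.broadcast whose vector
// source has a lower rank than the result into a shape_cast that pads the
// source with leading unit dims, followed by an nd -> nd broadcast.
struct MakeBroadcastRankPreserving : OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<VectorType>(op.getSourceType());
    VectorType resTy = op.getResultVectorType();
    if (!srcTy || srcTy.getRank() == resTy.getRank())
      return failure();
    // Pad the source shape with leading unit dims up to the result rank.
    SmallVector<int64_t> newShape(resTy.getRank() - srcTy.getRank(), 1);
    newShape.append(srcTy.getShape().begin(), srcTy.getShape().end());
    auto castTy = VectorType::get(newShape, srcTy.getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(op.getLoc(), castTy,
                                                      op.getSource());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, resTy, cast);
    return success();
  }
};
```

Whether this helps or instead conflicts with the slice-based propagation in this PR is exactly the open question in this thread.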
 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} :
AFAIK, I am not sure this is what we expect. `slice` is used to convey that the data is shared along some dimension, so the layout can simply drop those shared dims.
In this case, the shape cast result cannot be a slice layout, because it is 2D. Only the input to the shape cast is a slice (the result of the reduction).
  %cst_2 = arith.constant dense<1> : vector<32xi1>
  %offsets_1 = arith.constant dense<0> : vector<32xindex>
  %1 = xegpu.load %arg0[%offsets], %cst_4 : i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
  %2 = vector.broadcast %1 : vector<1xf32> to vector<32xf32>
Here also I am not sure whether %1 should have a slice attribute (because the rank is still 1).
We use slice in rank-altering operations (reduction).
The input vector of a broadcast operation has a lower rank than the broadcast result. In xegpu terms, this means that the input data is sliced (in the unit dimension).
Currently, the broadcast simply passes the result layout to the operand, which is incorrect.
This PR wraps the result layout in a slice attribute.
For the shape cast changes, I assume the shape-cast handling in the propagation code implicitly considers only `Nx1` or `1xN` kinds of shape cast results and does not need further slicing.
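To make that `Nx1` / `1xN` assumption explicit, the sliced-dim computation could be factored into a small guard; a sketch under that assumption (the helper name and the `std::optional` return are mine, the logic mirrors `slicedDim` in the patch):

```cpp
#include <optional>

#include "mlir/IR/BuiltinTypes.h"

// Maps a 2-D shape_cast result shape to the dimension that is sliced away on
// the 1-D operand; returns nullopt for shapes the propagation does not handle.
static std::optional<int64_t> getSlicedDim(mlir::VectorType resultTy) {
  if (resultTy.getRank() != 2)
    return std::nullopt;
  if (resultTy.getShape()[0] == 1)
    return 0; // 1xN: dim 0 is the unit dim being sliced away
  if (resultTy.getShape()[1] == 1)
    return 1; // Nx1: dim 1 is the unit dim being sliced away
  return std::nullopt;
}
```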