Conversation

@akroviakov
Contributor

The input vector of a broadcast operation has a lower rank than the broadcast result. In xegpu terms, this means that the input data is a slice of the result (along the broadcasted unit dimension).
Currently, layout propagation simply passes the result layout through to the broadcast operand, which is incorrect.
This PR wraps the result layout in a slice attribute before propagating it to the operand.
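
A sketch of the intended effect (hand-written IR, not taken from the test suite; the layout values are assumed, matching the new test below):

// The broadcast result keeps its plain layout.
%dst = vector.broadcast %src {layout_result_0 =
    #xegpu.layout<lane_layout = [16], lane_data = [1]>}
    : vector<1xf32> to vector<32xf32>
// With this PR, the operand %src is assigned the result layout wrapped in a
// slice over the broadcasted unit dimension (dim 0 here):
//   #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>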

For the shape cast changes, I assume that

 int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1; 

in the propagation code implicitly considers only Nx1 or 1xN shape cast results, so no further slicing is needed.
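
For reference, a minimal sketch of the two result shapes that heuristic distinguishes (shapes assumed for illustration):

// resultTy.getShape()[0] == 1  ->  slicedDim = 0
%a = vector.shape_cast %v : vector<16xf16> to vector<1x16xf16>
// otherwise                    ->  slicedDim = 1
%b = vector.shape_cast %w : vector<16xf16> to vector<16x1xf16>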

@llvmbot
Member

llvmbot commented Nov 21, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)


Full diff: https://github.com/llvm/llvm-project/pull/169054.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+18-5)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+27-2)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3a780abd3f12..2d8b5150d96fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -572,8 +572,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
         "one broadcasted dimension.");
     return;
   }
+  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+      broadcast->getContext(),
+      cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+      DenseI64ArrayAttr::get(broadcast->getContext(), {broadcastUnitDims[0]}));
   // Propagate the result layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
 }
 
 void LayoutInfoPropagation::visitShapeCastOp(
@@ -593,10 +597,19 @@ void LayoutInfoPropagation::visitShapeCastOp(
     return;
   }
   int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
-      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+  LayoutInfo operandLayout;
+  if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
+    auto sliceDims = sliceResultAttr.getDims().asArrayRef();
+    if (sliceDims.size() == 1 && sliceDims[0] == slicedDim)
+      operandLayout = resultLayout;
+  } else {
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        shapeCast->getContext(),
+        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+        DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
+    operandLayout = LayoutInfo(sliceLayout);
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index eb004932af4be..58ccb90f0bdb1 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -603,7 +603,7 @@ gpu.module @test {
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} :
 // CHECK-SAME: vector<16xf16> to vector<1x16xf16>
 func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -626,7 +626,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
 // CHECK-SAME: vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} :
 // CHECK-SAME: vector<16xf16> to vector<16x1xf16>
 func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -639,3 +639,28 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_slice_operand(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: i64) {
+// CHECK: %[[CST_0_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<1xindex>
+// CHECK: %[[CST_TRUE_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<1xi1>
+// CHECK: %[[CST_TRUE_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<32xi1>
+// CHECK: %[[CST_0_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<32xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%[[CST_0_1]]], %[[CST_TRUE_1]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>} :
+// CHECK-SAME: i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+// CHECK: %[[BCASTED:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<1xf32> to vector<32xf32>
+// CHECK: xegpu.store %[[BCASTED]], %[[ARG0]][%[[CST_0_32]]], %[[CST_TRUE_32]] : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+func.func @vector_broadcast_slice_operand(%arg0: i64) {
+  %offsets = arith.constant dense<0> : vector<1xindex>
+  %cst_4 = arith.constant dense<1> : vector<1xi1>
+  %cst_2 = arith.constant dense<1> : vector<32xi1>
+  %offsets_1 = arith.constant dense<0> : vector<32xindex>
+  %1 = xegpu.load %arg0[%offsets], %cst_4 : i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+  %2 = vector.broadcast %1 : vector<1xf32> to vector<32xf32>
+  xegpu.store %2, %arg0[%offsets_1], %cst_2 : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+  return
+}
+}
@github-actions

🐧 Linux x64 Test Results

  • 7114 tests passed
  • 594 tests skipped

-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
-      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+  LayoutInfo operandLayout;


please add a short comment here

+  if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
+    auto sliceDims = sliceResultAttr.getDims().asArrayRef();
+    if (sliceDims.size() == 1 && sliceDims[0] == slicedDim)
+      operandLayout = resultLayout;


maybe assert on the else condition?

Contributor

@charithaintc charithaintc left a comment


I think we need an additional canonicalization step to make broadcasts nd -> nd. I am not sure the change in this PR is necessary. @Jianhui-Li what do you think?
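
One possible reading of that suggestion (a hypothetical rewrite, not part of this PR): normalize rank-expanding broadcasts with a shape cast first, so that layout propagation only ever sees same-rank (nd -> nd) broadcasts.

// Hypothetical canonicalized form: the shape cast makes the ranks match
// before the broadcast (shapes assumed for illustration).
%cast = vector.shape_cast %src : vector<16xf16> to vector<1x16xf16>
%bcast = vector.broadcast %cast : vector<1x16xf16> to vector<8x16xf16>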

 // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} :
Contributor


AFAIK this is not what we expect. Slice is used to convey that the data is shared along some dimension, so the layout can simply drop those shared dims.

In this case the shape cast result cannot be a slice layout, because it is 2D. Only the input to the shape cast is a slice (the result of the reduction).
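
To illustrate the distinction (a sketch reusing the layouts from the tests above, shapes assumed): the rank-reducing reduction result carries a slice of the 2-D layout, while the 2-D shape cast result would again carry the plain 2-D layout.

// Rank-reducing op: the 1-D result is a slice of the 2-D layout.
%red = vector.multi_reduction <add>, %v, %acc {layout_result_0 =
    #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
    [0] : vector<16x16xf16> to vector<16xf16>
// Rank-restoring shape cast: the 2-D result carries the plain 2-D layout.
%sc = vector.shape_cast %red {layout_result_0 =
    #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
    : vector<16xf16> to vector<1x16xf16>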

+  %cst_2 = arith.constant dense<1> : vector<32xi1>
+  %offsets_1 = arith.constant dense<0> : vector<32xindex>
+  %1 = xegpu.load %arg0[%offsets], %cst_4 : i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+  %2 = vector.broadcast %1 : vector<1xf32> to vector<32xf32>
Contributor


here also I am not sure if %1 should have a slice attribute (because its rank is still 1).

We use slice in rank-altering operations (reduction).


4 participants