[mlir][vector] Add support for vector.multi_reduction and vector.shape_cast distribution. #154438
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes: This PR adds support for vector.multi_reduction distribution. The PR also includes changes in vector.shape_cast distribution to consider the distributed type of the shape cast source if given by a DistributionMapFn.

Full diff: https://github.com/llvm/llvm-project/pull/154438.diff

2 Files Affected:
- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
 #include <utility>
 
 using namespace mlir;
@@ -977,44 +983,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, source distribted type is the same
+    // as the source type. If the result is distributed, we need to compute the
+    // distributed source type according to following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    // matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need
+    // to compute the distributed source type based on the distribution map
+    // and the warp size.
+    if (isResultDistributed) {
+      // Check if the source is yielded from the warp op.
+      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      auto *it =
+          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+            return operand.get() == oldCastOp.getSource();
+          });
+
+      if (it != yieldOp->getOpOperands().end()) {
+        // If the source is yielded from the warp op, we can use the matching
+        // warp result type as the distributed source type.
+        distributedSourceType =
+            cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+      } else {
+        // If the source is not yielded from the warp op, we need to compute
+        // the distributed source type based on the distribution map and the
+        // warp size.
+        AffineMap map = distributionMapFn(oldCastOp.getSource());
+        distributedSourceType =
+            getDistributedType(sourceType, map, warpOp.getWarpSize());
+        if (!distributedSourceType)
+          return rewriter.notifyMatchFailure(
+              oldCastOp,
+              "cannot compute distributed source type for shape cast");
+      }
     }
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
        newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value newCast = vector::ShapeCastOp::create(
-        rewriter, oldCastOp.getLoc(), castResultType,
+        rewriter, oldCastOp.getLoc(), distributedResultType,
         newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2033,107 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+  VectorMultiDimReductionDistribution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : WarpDistributionPattern(context, benefit) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+
+    // Col reduction.
+    if (reductionDims[0] == 0) {
+      // Source vector must be distributable to lanes in the col dimension.
+      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source vector dimension must be divisible by warp size.");
+      // Compute source distributed type.
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
+      // result back to the result vector.
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    int nRows = sourceType.getShape()[0];
+    // For each row of the source, extract the row vector, do a reduction and,
+    // insert the result back to the result.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+    // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2017,15 +2155,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
   patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
+           WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+           WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
-  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
-                               benefit);
+  patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
+                                                distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
   return %r : f32
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %acc = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+    gpu.yield %1 : vector<64xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP: return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+  %zero = arith.constant dense<0.0> : vector<2xf32>
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<2x32xf32>)
+    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+    gpu.yield %1 : vector<2xf32>
+  }
+  return %r : vector<2xf32>
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP: return %[[CAST]] : vector<4xf32>
 
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+    %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+    gpu.yield %3 : vector<32x64xf32>
+  }
+  return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+    %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+    %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+    gpu.yield %3 : vector<8x4x2xf32>
+  }
+  return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32>
+
 // -----
 
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
cc @Garra1980
adam-smnk left a comment
General direction looks good to me.
I'll let others have a closer look at the distribution logic.
Hi @adam-smnk, thanks for the reviews. I addressed the comments; please let me know if you have any additional concerns. I also properly documented the restrictions in this version.
// distributed source type according to following rules:
// 1. If the source type is yielded from the warp op, we can use the
// matching warp result type as the distributed source type.
// 2. If the source type is not yielded from the warp op, we need
Could you give some examples? Can it happen that there are conflicts between these two rules?
I added two examples for the two cases as comments.
For both row and col reductions we assume that the columns of the source vector are owned by the lanes (i.e. in xegpu layouts this would be [1, 16]). The reduction logic is handled based on that assumption.
Given this layout,
Col reduction is easy: each lane just reduces its own data (see the sketch below).
Row reduction: needs to shuffle data with neighbors and do a tree-like reduce (aka butterfly reduction with shuffles).
However, the source layout can also be [16, 1]. This case is not supported because the vector distribution infra does not currently allow expressing such a layout (it always starts distributing the vector from the innermost dim). I am working on a proposal to improve this.
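To make the col-reduction case concrete, here is a minimal sketch of the per-lane work after distribution (my own illustration, not code from the PR). It assumes a warp size of 32 and a vector<32x64xf32> source, so each lane holds a vector<32x2xf32> slice, i.e. two whole columns. The standalone function is hypothetical; in the pattern, this sequence is emitted right after the gpu.warp_execute_on_lane_0 op.

```mlir
// Hypothetical per-lane view: %src is this lane's vector<32x2xf32> slice of the
// vector<32x64xf32> source, %acc holds the two accumulator elements it owns.
func.func @per_lane_col_reduce(%src: vector<32x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %zero = arith.constant dense<0.0> : vector<2xf32>
  // Column 0 owned by this lane: extract, flatten to 1D, reduce against its accumulator.
  %c0 = vector.extract_strided_slice %src {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
  %c0v = vector.shape_cast %c0 : vector<32x1xf32> to vector<32xf32>
  %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
  %r0 = vector.reduction <add>, %c0v, %a0 : vector<32xf32> into f32
  %partial = vector.insert %r0, %zero[0] : f32 into vector<2xf32>
  // Column 1 owned by this lane.
  %c1 = vector.extract_strided_slice %src {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
  %c1v = vector.shape_cast %c1 : vector<32x1xf32> to vector<32xf32>
  %a1 = vector.extract %acc[1] : f32 from vector<2xf32>
  %r1 = vector.reduction <add>, %c1v, %a1 : vector<32xf32> into f32
  %res = vector.insert %r1, %partial[1] : f32 into vector<2xf32>
  return %res : vector<2xf32>
}
```

No cross-lane communication is needed here; the distributed result vector<2xf32> is simply the lane's two column sums.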
                                               reductionOp.getSource(), i);
      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
                                            reductionOp.getAcc(), i);
      Value rowReduce = vector::ReductionOp::create(
I am curious: how is the ShuffleOp inserted?
This is lowered progressively. Here we lower it to a bunch of vector.reduction ops. Then the WarpOpReduction pattern kicks in and does the actual distribution to shuffle ops.
WarpOpReduction is free to use any reduction strategy (specified by distributedReductionFn). Currently it uses by default the one defined here:
https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp#L566
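For reference, here is a minimal sketch of the butterfly reduction that the shuffle-based strategy produces for one per-lane scalar (my own illustration, not the exact output of the PR; the wrapper function is hypothetical): xor-shuffles with offsets 1, 2, 4, 8, 16 over a warp of 32 lanes, accumulating after each exchange.

```mlir
// Hypothetical wrapper: %x is this lane's partial value for one row.
func.func @butterfly_add(%x: f32) -> f32 {
  %width = arith.constant 32 : i32
  %o1 = arith.constant 1 : i32
  %o2 = arith.constant 2 : i32
  %o4 = arith.constant 4 : i32
  %o8 = arith.constant 8 : i32
  %o16 = arith.constant 16 : i32
  // Each step exchanges values with the lane whose id differs in one bit,
  // then accumulates, halving the number of distinct partial sums.
  %s1, %p1 = gpu.shuffle xor %x, %o1, %width : f32
  %a1 = arith.addf %x, %s1 : f32
  %s2, %p2 = gpu.shuffle xor %a1, %o2, %width : f32
  %a2 = arith.addf %a1, %s2 : f32
  %s4, %p4 = gpu.shuffle xor %a2, %o4, %width : f32
  %a4 = arith.addf %a2, %s4 : f32
  %s8, %p8 = gpu.shuffle xor %a4, %o8, %width : f32
  %a8 = arith.addf %a4, %s8 : f32
  %s16, %p16 = gpu.shuffle xor %a8, %o16, %width : f32
  %a16 = arith.addf %a8, %s16 : f32
  // After five exchanges every lane holds the full row sum.
  return %a16 : f32
}
```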
chencha3 left a comment
Overall, it looks good to me. It might be helpful to have someone familiar with the distribution mechanism review it for accuracy.
@chencha3 @adam-smnk I have addressed the comments. Please let me know if there are any other concerns; if not, please consider approving the PR.
@akroviakov Maybe you could double-check the distribution logic?
adam-smnk left a comment
Overall seems fine. Minor comments.
@chencha3 I'm not sure about the shuffles; if you think they're fine then the PR should be good to go.
// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
Why is it not
%[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64x1xf32> to vector<32x2xf32>
The inner dim of the transfer_read result is distributed in this case, so it becomes 64x32 -> 2x32 after distribution.
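For illustration, here is a minimal before/after sketch of that test case, based on the CHECK lines quoted above (the standalone function is hypothetical). Inside the warp region the cast is vector<64x32xf32> to vector<32x64xf32>; per lane it becomes vector<2x32xf32> to vector<32x2xf32>, matching the distributed warp result type.

```mlir
// Hypothetical per-lane view after distribution: %read is this lane's slice of
// the transfer_read result, and the shape_cast now operates on distributed types.
func.func @shape_cast_per_lane(%read: vector<2x32xf32>) -> vector<32x2xf32> {
  %cast = vector.shape_cast %read : vector<2x32xf32> to vector<32x2xf32>
  return %cast : vector<32x2xf32>
}
```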
Refactored the implementation in #157560.
This PR adds support for vector.multi_reduction distribution. Currently only 2D to 1D reductions are supported (col/row reductions), and the pattern assumes the inner dimension of the source vector is distributed among lanes (each lane owns columns of the source vector).
- Col reduction: each lane does a vector.reduction of the column data it owns.
- Row reduction: the multi_reduction is rewritten in terms of vector.reduction ops, which are then distributed by the existing reduction pattern (a sketch of this rewrite follows below).

The PR also includes changes in vector.shape_cast distribution to consider the distributed type of the shape cast source if given by a DistributionMapFn.
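As a sketch of the row-reduction rewrite mentioned above (my own example, assuming a vector<2x32xf32> source, a vector<2xf32> accumulator, and add as the combining kind; the standalone function is hypothetical), the multi_reduction is first expressed as one vector.reduction per row at the original, undistributed types. The WarpOpReduction pattern then distributes each of these vector.reduction ops across the lanes.

```mlir
// Step 1 of the progressive lowering for a [1] (row) reduction:
// one vector.reduction per row, still on undistributed types.
func.func @row_reduce_rewrite(%src: vector<2x32xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %zero = arith.constant dense<0.0> : vector<2xf32>
  // Row 0.
  %row0 = vector.extract %src[0] : vector<32xf32> from vector<2x32xf32>
  %acc0 = vector.extract %acc[0] : f32 from vector<2xf32>
  %r0 = vector.reduction <add>, %row0, %acc0 : vector<32xf32> into f32
  %partial = vector.insert %r0, %zero[0] : f32 into vector<2xf32>
  // Row 1.
  %row1 = vector.extract %src[1] : vector<32xf32> from vector<2x32xf32>
  %acc1 = vector.extract %acc[1] : f32 from vector<2xf32>
  %r1 = vector.reduction <add>, %row1, %acc1 : vector<32xf32> into f32
  %res = vector.insert %r1, %partial[1] : f32 into vector<2xf32>
  return %res : vector<2xf32>
}
```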