[mlir][vector] Add support for vector.multi_reduction and vector.shape_cast distribution. #154438
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -17,6 +17,7 @@ | |
| #include "mlir/IR/AffineExpr.h" | ||
| #include "mlir/IR/Attributes.h" | ||
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/IR/Value.h" | ||
| #include "mlir/Interfaces/SideEffectInterfaces.h" | ||
| #include "mlir/Transforms/RegionUtils.h" | ||
| #include "llvm/ADT/SetVector.h" | ||
| | @@ -1020,44 +1021,75 @@ struct WarpOpBroadcast : public WarpDistributionPattern { | |
| /// Pattern to move shape cast out of the warp op. shape cast is basically a | ||
| /// no-op for warp distribution; we need to handle the shape though. | ||
| struct WarpOpShapeCast : public WarpDistributionPattern { | ||
| using Base::Base; | ||
| | ||
| WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) | ||
| : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} | ||
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, | ||
| PatternRewriter &rewriter) const override { | ||
| OpOperand *operand = | ||
| getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>); | ||
| if (!operand) | ||
| return failure(); | ||
| | ||
| auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>(); | ||
| | ||
| unsigned int operandNumber = operand->getOperandNumber(); | ||
| auto castDistributedType = | ||
| unsigned operandNumber = operand->getOperandNumber(); | ||
| VectorType sourceType = oldCastOp.getSourceVectorType(); | ||
| VectorType distributedResultType = | ||
| cast<VectorType>(warpOp->getResultTypes()[operandNumber]); | ||
| VectorType castOriginalType = oldCastOp.getSourceVectorType(); | ||
| VectorType castResultType = castDistributedType; | ||
| | ||
| // We expect the distributed type to have a smaller rank than the original | ||
| // type. Prepend with size-one dimensions to make them the same. | ||
| unsigned castDistributedRank = castDistributedType.getRank(); | ||
| unsigned castOriginalRank = castOriginalType.getRank(); | ||
| if (castDistributedRank < castOriginalRank) { | ||
| SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1); | ||
| llvm::append_range(shape, castDistributedType.getShape()); | ||
| castDistributedType = | ||
| VectorType::get(shape, castDistributedType.getElementType()); | ||
| VectorType distributedSourceType = sourceType; | ||
| bool isResultDistributed = distributedResultType.getNumElements() < | ||
| oldCastOp.getResultVectorType().getNumElements(); | ||
| | ||
|         // If the result is not distributed, the distributed source type is the |		||
|         // same as the source type. If the result is distributed, we need to |		||
|         // compute the distributed source type according to the following rules: |		||
| // 1. If the source type is yielded from the warp op, we can use the | ||
| // matching warp result type as the distributed source type. | ||
| // 2. If the source type is not yielded from the warp op, we need | ||
|         // to compute the distributed source type based on the distribution map |		||
|         // and the warp size. |		||
Reviewer (Contributor): Could you give some examples? Will there be conflicts between these two rules?

Author (Contributor): I added two examples for the two cases as comments. For both row and column reductions we assume that the columns of the source vector are owned by each lane (i.e. in xegpu layouts this will be [1, 16]), and the reduction logic is handled based on that. However, the source layout can also be [16, 1]. That case is not supported because the vector distribution infrastructure currently does not allow such a layout to be expressed (it always starts distributing the vector from the innermost dimension). I am working on a proposal to improve this.
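To make the two rules concrete, here is a hypothetical sketch; the shapes, warp size, and the `"some_def"` producers are illustrative and not taken from the patch or its tests:

```mlir
// Rule 1: the shape_cast source %s is itself yielded from the warp op, so its
// distributed type can be read directly off the matching warp result
// (%r#1 : vector<1x1xf32> below).
%r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, vector<1x1xf32>) {
  %s = "some_def"() : () -> (vector<1x32xf32>)
  %c = vector.shape_cast %s : vector<1x32xf32> to vector<32xf32>
  gpu.yield %c, %s : vector<32xf32>, vector<1x32xf32>
}

// Rule 2: %s2 is not yielded, so the distributed source type has to be computed
// from the DistributionMapFn and the warp size (with a map that distributes the
// innermost dimension it would again come out as vector<1x1xf32>).
%r2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %s2 = "some_def"() : () -> (vector<1x32xf32>)
  %c2 = vector.shape_cast %s2 : vector<1x32xf32> to vector<32xf32>
  gpu.yield %c2 : vector<32xf32>
}
```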
| if (isResultDistributed) { | ||
| // Check if the source is yielded from the warp op. | ||
| gpu::YieldOp yieldOp = cast<gpu::YieldOp>( | ||
| warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); | ||
| OpOperand *it = | ||
| llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { | ||
| return operand.get() == oldCastOp.getSource(); | ||
| }); | ||
| | ||
| if (it != yieldOp->getOpOperands().end()) { | ||
| // If the source is yielded from the warp op, we can use the matching | ||
| // warp result type as the distributed source type. | ||
| distributedSourceType = | ||
| cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]); | ||
| } else { | ||
| // If the source is not yielded from the warp op, we need to compute | ||
| // the distributed source type based on the distribution map and the | ||
| // warp size. | ||
| AffineMap map = distributionMapFn(oldCastOp.getSource()); | ||
| distributedSourceType = | ||
| getDistributedType(sourceType, map, warpOp.getWarpSize()); | ||
| if (!distributedSourceType) | ||
| return rewriter.notifyMatchFailure( | ||
| oldCastOp, | ||
| "cannot compute distributed source type for shape cast"); | ||
| } | ||
| } | ||
| | ||
| SmallVector<size_t> newRetIndices; | ||
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( | ||
| rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, | ||
| rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType}, | ||
| newRetIndices); | ||
| rewriter.setInsertionPointAfter(newWarpOp); | ||
| Value newCast = vector::ShapeCastOp::create( | ||
| rewriter, oldCastOp.getLoc(), castResultType, | ||
| rewriter, oldCastOp.getLoc(), distributedResultType, | ||
| newWarpOp->getResult(newRetIndices[0])); | ||
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); | ||
| return success(); | ||
| } | ||
| | ||
| private: | ||
| DistributionMapFn distributionMapFn; | ||
| }; | ||
| | ||
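As a hedged before/after sketch of what WarpOpShapeCast does (the shapes, warp size, and the `"some_def"` producer are illustrative, not from the patch):

```mlir
// Before: the shape_cast result is yielded from the warp op.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>) {
  %s = "some_def"() : () -> (vector<4x32xf32>)
  %c = vector.shape_cast %s : vector<4x32xf32> to vector<128xf32>
  gpu.yield %c : vector<128xf32>
}

// After: the source is additionally yielded in distributed form and the
// shape_cast is re-created outside the warp op on the distributed types.
%w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>, vector<4x1xf32>) {
  %s = "some_def"() : () -> (vector<4x32xf32>)
  %c = vector.shape_cast %s : vector<4x32xf32> to vector<128xf32>
  gpu.yield %c, %s : vector<128xf32>, vector<4x32xf32>
}
%cast = vector.shape_cast %w#1 : vector<4x1xf32> to vector<4xf32>
// Uses of %r are replaced with %cast; the now-dead %w#0 is expected to be
// removed later by the WarpOpDeadResult pattern.
```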
| /// Sink out vector.create_mask op feeding into a warp op yield. | ||
| | @@ -2038,6 +2070,180 @@ struct WarpOpReduction : public WarpDistributionPattern { | |
| DistributedReductionFn distributedReductionFn; | ||
| }; | ||
| | ||
| /// This pattern distributes the `vector.multi_reduction` operation across |		||
| /// lanes in a warp. Currently only 2D to 1D reductions are supported, and the |		||
| /// source vector is assumed to be distributed along the column dimension |		||
| /// (i.e. each lane owns complete column(s) of the source vector). |		||
| /// TODO: Add support for the case where source rows are distributed across | ||
| /// lanes. Requires `DistributionMapFn` to express the data distribution. | ||
| /// Example 1 (Col reduction): | ||
| /// ``` | ||
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { | ||
| /// %0 = "some_def"() : () -> (vector<16x32xf32>) | ||
| /// %acc = "some_def"() : () -> (vector<32xf32>) | ||
| /// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to | ||
| ///   vector<32xf32> |		||
| ///   gpu.yield %1 : vector<32xf32> |		||
| /// } | ||
| /// ``` | ||
| /// is lowered to: | ||
| /// ``` | ||
| /// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>, | ||
| /// vector<1xf32>) { | ||
| /// %0 = "some_def"() : () -> (vector<16x32xf32>) | ||
| /// %acc = "some_def"() : () -> (vector<32xf32>) | ||
| /// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32> | ||
| /// } | ||
| /// %c = arith.constant dense<0.0> : vector<1xf32> | ||
| /// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32> | ||
| /// %e = vector.extract %r#1[0] : f32 from vector<1xf32> |		||
| /// %2 = vector.reduction <add>, %1, %e : vector<16xf32> into f32 |		||
| /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32> | ||
| /// ``` | ||
| /// Example 2 (Row reduction): | ||
| /// ``` | ||
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { | ||
| /// %0 = "some_def"() : () -> (vector<2x32xf32>) | ||
| /// %acc = "some_def"() : () -> (vector<2xf32>) | ||
| /// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to | ||
| /// vector<2xf32> | ||
| /// gpu.yield %1 : vector<2xf32> | ||
| /// } | ||
| /// ``` | ||
| /// is lowered to: | ||
| /// ``` | ||
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { | ||
| /// %0 = "some_def"() : () -> (vector<2x32xf32>) | ||
| /// %acc = "some_def"() : () -> (vector<2xf32>) | ||
| /// %1 = arith.constant dense<0.0> : vector<2xf32> | ||
| ///   %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32> |		||
| /// %3 = ("warp.reduction %2") : f32 | ||
| /// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32> | ||
| /// ... repeat for row 1 | ||
| /// gpu.yield %1 : vector<2xf32> | ||
| /// } |		||
| /// ``` |		||
| struct WarpOpMultiReduction : public WarpDistributionPattern { | ||
| using Base::Base; | ||
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, | ||
| PatternRewriter &rewriter) const override { | ||
| OpOperand *yieldOperand = | ||
| getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>); | ||
| if (!yieldOperand) | ||
| return failure(); | ||
| auto reductionOp = | ||
| cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp()); | ||
| unsigned operandNumber = yieldOperand->getOperandNumber(); | ||
| VectorType sourceType = reductionOp.getSourceVectorType(); | ||
| | ||
| // Only 2D vectors are supported. | ||
| if (sourceType.getRank() != 2) | ||
| return rewriter.notifyMatchFailure(warpOp, | ||
| "Only 2D reductions are supported."); | ||
| ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims(); | ||
|     // Only 1 reduction dimension is supported. This also ensures that the |		||
|     // result is a vector type. |		||
| if (reductionDims.size() != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| warpOp, "Only 1 reduction dimension is supported."); | ||
| int64_t reductionDim = reductionDims[0]; | ||
| auto resultType = cast<VectorType>(reductionOp.getType()); | ||
| auto distributedResultType = | ||
| cast<VectorType>(warpOp.getResult(operandNumber).getType()); | ||
| Type elementType = distributedResultType.getElementType(); | ||
| | ||
| // Currently we make the following assumptions. | ||
| // 1. The source vector is distributed in the column dimension. Each lane | ||
| // owns complete column(s) of the source vector. | ||
|     // 2. If the reduction dim == 0, it is a lane-local col reduction. In this |		||
|     // case each lane owns its portion of the result (i.e. the result is also |		||
|     // distributed). |		||
|     // 3. If the reduction dim == 1, it is a row reduction that requires |		||
|     // cross-lane shuffles. In this case, the reduction result is not |		||
|     // distributed across lanes. Instead each lane owns a complete copy of the |		||
|     // result (broadcasted). |		||
| // TODO: These assumptions are fairly restrictive. For example, source | ||
| // vector can have row distributed layout. Improve support for such cases. | ||
| if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) | ||
| return rewriter.notifyMatchFailure( | ||
| warpOp, "Source vector dimension must be divisible by warp size."); | ||
| bool isResultDistributed = | ||
| distributedResultType.getNumElements() < resultType.getNumElements(); | ||
| if (reductionDim == 0 && !isResultDistributed) | ||
| return rewriter.notifyMatchFailure( | ||
| warpOp, | ||
| "Expecting result vector to be distributed in a col reduction."); | ||
| if (reductionDim == 1 && isResultDistributed) | ||
| return rewriter.notifyMatchFailure( | ||
| warpOp, | ||
| "Expecting result vector to be broadcasted in a row reduction."); | ||
| | ||
| // Create a constant vector to store the result of the reduction per lane. | ||
| TypedAttr zeroAttr = | ||
| rewriter.getZeroAttr(distributedResultType.getElementType()); | ||
| Value result = arith::ConstantOp::create( | ||
| rewriter, reductionOp->getLoc(), distributedResultType, | ||
| DenseElementsAttr::get(distributedResultType, zeroAttr)); | ||
| // Col reduction. | ||
| if (reductionDim == 0) { | ||
| // Compute source distributed type assuming each lane owns cols. | ||
| SmallVector<int64_t> shape(sourceType.getShape()); | ||
| shape[1] = shape[1] / warpOp.getWarpSize(); | ||
| auto sourceDistributedType = VectorType::get(shape, elementType); | ||
| | ||
| // Yield the source and acc vectors from the WarpOp. | ||
| SmallVector<size_t> newRetIndices; | ||
| auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( | ||
| rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, | ||
| {sourceDistributedType, distributedResultType}, newRetIndices); | ||
| rewriter.setInsertionPointAfter(newWarpOp); | ||
| | ||
| int nCols = sourceDistributedType.getShape()[1]; | ||
| Value source = newWarpOp.getResult(newRetIndices[0]); | ||
| Value acc = newWarpOp.getResult(newRetIndices[1]); | ||
|       // For each column owned by a lane, extract the column (of size nRows x |		||
|       // 1), shape cast it to 1D (nRows), do a vector.reduction, and insert |		||
|       // the result back into the result vector. |		||
| for (int i = 0; i < nCols; ++i) { | ||
| Value col = vector::ExtractStridedSliceOp::create( | ||
| rewriter, reductionOp.getLoc(), source, {0, i}, | ||
| {sourceDistributedType.getShape()[0], 1}, {1, 1}); | ||
| col = vector::ShapeCastOp::create( | ||
| rewriter, reductionOp.getLoc(), | ||
| VectorType::get({sourceDistributedType.getShape()[0]}, elementType), | ||
| col); | ||
| Value accCol = | ||
| vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); | ||
| Value colReduce = vector::ReductionOp::create( | ||
| rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); | ||
| result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), | ||
| colReduce, result, i); | ||
| } | ||
| // Replace the warp op result with the new reduction op. | ||
| rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); | ||
| return success(); | ||
| } | ||
| // For row reductions, we simply rewrite the MultiReductionOp in terms of | ||
| // multiple ReductionOps. Actual distribution is done by the WarpOpReduction | ||
| // pattern. | ||
| rewriter.setInsertionPointAfter(reductionOp); | ||
| int nRows = sourceType.getShape()[0]; | ||
|     // For each row of the source, extract the row vector, do a reduction, and |		||
|     // insert the result back into the result vector. |		||
| for (int i = 0; i < nRows; ++i) { | ||
| Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), | ||
| reductionOp.getSource(), i); | ||
| Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), | ||
| reductionOp.getAcc(), i); | ||
| Value rowReduce = vector::ReductionOp::create( | ||
|           rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc); |		||
Reviewer (Contributor): I am curious, how is the ShuffleOp inserted?

Author (Contributor): This is lowered progressively. Here we only lower the multi-reduction to a set of `vector.reduction` ops; the existing WarpOpReduction pattern then distributes those and introduces the shuffles.
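To illustrate the reply above, a hypothetical intermediate form for the row-reduction case from Example 2 (a sketch only; shapes and SSA names are illustrative). The `vector.reduction` ops left inside the warp op are what WarpOpReduction later distributes, and that is where the shuffles appear:

```mlir
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %0 = "some_def"() : () -> (vector<2x32xf32>)
  %acc = "some_def"() : () -> (vector<2xf32>)
  %cst = arith.constant dense<0.000000e+00> : vector<2xf32>
  // Row 0.
  %row0 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
  %acc0 = vector.extract %acc[0] : f32 from vector<2xf32>
  %red0 = vector.reduction <add>, %row0, %acc0 : vector<32xf32> into f32
  %ins0 = vector.insert %red0, %cst [0] : f32 into vector<2xf32>
  // Row 1.
  %row1 = vector.extract %0[1] : vector<32xf32> from vector<2x32xf32>
  %acc1 = vector.extract %acc[1] : f32 from vector<2xf32>
  %red1 = vector.reduction <add>, %row1, %acc1 : vector<32xf32> into f32
  %ins1 = vector.insert %red1, %ins0 [1] : f32 into vector<2xf32>
  gpu.yield %ins1 : vector<2xf32>
}
```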
| result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), | ||
| rowReduce, result, i); | ||
| } | ||
| // Replace the warp op result with the final result. | ||
| rewriter.replaceAllUsesWith(reductionOp.getResult(), result); | ||
| | ||
| return success(); | ||
| } | ||
| }; | ||
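For the column-reduction branch, a hypothetical worked sketch where each lane owns two columns (the 16x64 shape, warp size 32, and value names are illustrative and not taken from the patch or its tests):

```mlir
// Input: a column (dim-0) reduction yielded from the warp op.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %0 = "some_def"() : () -> (vector<16x64xf32>)
  %acc = "some_def"() : () -> (vector<64xf32>)
  %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x64xf32> to vector<64xf32>
  gpu.yield %1 : vector<64xf32>
}

// After the pattern: source and acc are yielded in distributed form
// (vector<16x2xf32> and vector<2xf32> per lane), and each lane-local column
// is reduced with a plain vector.reduction.
%w:3 = gpu.warp_execute_on_lane_0(%laneid)[32]
    -> (vector<2xf32>, vector<16x2xf32>, vector<2xf32>) {
  %0 = "some_def"() : () -> (vector<16x64xf32>)
  %acc = "some_def"() : () -> (vector<64xf32>)
  %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x64xf32> to vector<64xf32>
  gpu.yield %1, %0, %acc : vector<64xf32>, vector<16x64xf32>, vector<64xf32>
}
%cst = arith.constant dense<0.000000e+00> : vector<2xf32>
// Column 0 owned by this lane.
%col0 = vector.extract_strided_slice %w#1 {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
%col0_1d = vector.shape_cast %col0 : vector<16x1xf32> to vector<16xf32>
%acc0 = vector.extract %w#2[0] : f32 from vector<2xf32>
%red0 = vector.reduction <add>, %col0_1d, %acc0 : vector<16xf32> into f32
%res0 = vector.insert %red0, %cst [0] : f32 into vector<2xf32>
// Column 1 owned by this lane.
%col1 = vector.extract_strided_slice %w#1 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
%col1_1d = vector.shape_cast %col1 : vector<16x1xf32> to vector<16xf32>
%acc1 = vector.extract %w#2[1] : f32 from vector<2xf32>
%red1 = vector.reduction <add>, %col1_1d, %acc1 : vector<16xf32> into f32
%res = vector.insert %red1, %res0 [1] : f32 into vector<2xf32>
// Uses of the original warp result are replaced with %res; the dead %w#0 is
// cleaned up later by WarpOpDeadResult.
```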
| | ||
| } // namespace | ||
| | ||
| void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( | ||
| | @@ -2059,15 +2265,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( | |
| PatternBenefit readBenefit) { | ||
| patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit); | ||
| patterns | ||
| .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, | ||
| WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant, | ||
| WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask, | ||
| WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>( | ||
| .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract, | ||
| WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar, | ||
| WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice, | ||
| WarpOpInsertStridedSlice, WarpOpMultiReduction, WarpOpStep>( | ||
| patterns.getContext(), benefit); | ||
| patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn, | ||
| benefit); | ||
| patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn, | ||
| benefit); | ||
| patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(), | ||
| distributionMapFn, benefit); | ||
| } | ||
| | ||
| void mlir::vector::populateDistributeReduction( | ||
| | ||