 //
 // ===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -362,6 +366,163 @@ struct VectorSubgroupReduceToShuffles final
   unsigned shuffleBitwidth = 0;
   bool matchClustered = false;
 };
+
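+/// Performs a cluster-wide reduction of `input` over `ci.clusterSize`
+/// contiguous lanes by repeatedly exchanging values with `amdgpu.dpp` ops
+/// (and, on gfx10 through gfx12, ROCDL permlane/readlane ops) and combining
+/// the partial results with the arithmetic reduction selected by `mode`.
+/// Returns failure on chipsets for which the required cross-row steps are not
+/// currently supported.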
+static FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
+  constexpr int allRows = 0xf;
+  constexpr int allBanks = 0xf;
+  const bool boundCtrl = true;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast the last lane of each row to the next row.
+      // Use a row mask so that only rows 1 and 3 are updated.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        res =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      }
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast lane 31 to rows 2 and 3.
+      // Use a row mask so that rows 0 and 1 are left untouched.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume the reduction across the first 32 lanes has already been done.
+      // Perform the final reduction by combining the values now held in lane 0
+      // and lane 32.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  assert(res.getType() == input.getType());
+  return res;
+}
+
+/// Lowers `gpu.subgroup_reduce` ops over scalar types to a sequence of
+/// `amdgpu.dpp` ops. Assumes that the subgroup has `subgroupSize` lanes.
+/// Applicable only to AMD GPUs.
+struct ScalarSubgroupReduceToDPP final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
+                            bool matchClustered, amdgpu::Chipset chipset,
+                            PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+
+    if (ci->clusterStride != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reductions using DPP are currently only available for "
+              "clusters of contiguous lanes.");
+
+    Type valueTy = op.getType();
+    if (!valueTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Value type is not a compatible scalar.");
+
+    FailureOr<Value> dpp = createSubgroupDPPReduction(
+        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
+    if (failed(dpp))
+      return failure();
+
+    rewriter.replaceOp(op, dpp.value());
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+  amdgpu::Chipset chipset;
+};
 } // namespace
 
 void mlir::populateGpuBreakDownSubgroupReducePatterns(
@@ -372,6 +533,22 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
   patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
 }
 
+void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/false, chipset,
+                                          benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/true, chipset,
+                                          benefit);
+}
+
 void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {