1313#include " mlir/Dialect/Arith/Transforms/Passes.h"
1414#include " mlir/Dialect/Arith/Utils/Utils.h"
1515#include " mlir/Dialect/MemRef/IR/MemRef.h"
16- #include " mlir/Dialect/MemRef/Transforms/Passes.h"
1716#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
1817#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1918#include " mlir/Dialect/Vector/IR/VectorOps.h"
2423#include " mlir/Support/MathExtras.h"
2524#include " mlir/Transforms/DialectConversion.h"
2625#include " llvm/Support/FormatVariadic.h"
27- #include " llvm/Support/MathExtras.h"
2826#include < cassert>
2927#include < type_traits>
3028
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
430428 }
431429};
432430
431+ // ===----------------------------------------------------------------------===//
432+ // ConvertMemRefCollapseShape
433+ // ===----------------------------------------------------------------------===//
434+
435+ // / Emulating a `memref.collapse_shape` becomes a no-op after emulation given
436+ // / that we flatten memrefs to a single dimension as part of the emulation and
437+ // / there is no dimension to collapse any further.
438+ struct ConvertMemRefCollapseShape final
439+ : OpConversionPattern<memref::CollapseShapeOp> {
440+ using OpConversionPattern::OpConversionPattern;
441+
442+ LogicalResult
443+ matchAndRewrite (memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
444+ ConversionPatternRewriter &rewriter) const override {
445+ Value srcVal = adaptor.getSrc ();
446+ auto newTy = dyn_cast<MemRefType>(srcVal.getType ());
447+ if (!newTy)
448+ return failure ();
449+
450+ if (newTy.getRank () != 1 )
451+ return failure ();
452+
453+ rewriter.replaceOp (collapseShapeOp, srcVal);
454+ return success ();
455+ }
456+ };
457+
433458} // end anonymous namespace
434459
435460// ===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
442467
443468 // Populate `memref.*` conversion patterns.
444469 patterns.add <ConvertMemRefAllocation<memref::AllocOp>,
445- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
470+ ConvertMemRefAllocation<memref::AllocaOp>,
471+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
446472 ConvertMemrefStore, ConvertMemRefAssumeAlignment,
447473 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
448474 typeConverter, patterns.getContext ());
0 commit comments