@dcaballe
Contributor

This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` op becomes a no-op because the emulation already flattens the memref to a single dimension (i.e., all dimensions are already collapsed), so there is nothing left to collapse.
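To illustrate why the collapse carries no information once the memref is flattened, here is a minimal sketch of row-major index linearization. It assumes a contiguous, identity-layout memref; `linearize` is a hypothetical helper written for this example and is not part of MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): row-major linearization of a
// multi-dimensional index, assuming a contiguous identity layout.
std::int64_t linearize(const std::vector<std::int64_t> &shape,
                       const std::vector<std::int64_t> &indices) {
  assert(shape.size() == indices.size() && "one index per dimension");
  std::int64_t linear = 0;
  // Fold the indices from the outermost dimension inward.
  for (std::size_t d = 0; d < shape.size(); ++d)
    linear = linear * shape[d] + indices[d];
  return linear;
}
```

Under these assumptions, element `(i, j, k)` of a `32x8x128` memref and element `(i * 8 + j, k)` of the collapsed `256x128` memref map to the same linear position, so the flattened buffer can be used for both.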

@llvmbot
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` op becomes a no-op because the emulation already flattens the memref to a single dimension (i.e., all dimensions are already collapsed), so there is nothing left to collapse.


Full diff: https://github.com/llvm/llvm-project/pull/89962.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+29-3)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+20)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// there is no dimension to collapse any further.
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return failure();
+
+    if (newTy.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAllocation<memref::AllocaOp>,
+               ConvertMemRefCollapseShape, ConvertMemRefLoad,
                ConvertMemrefStore, ConvertMemRefAssumeAlignment,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
 // CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
 // CHECK32: return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+  %arr = memref.alloc() : memref<32x8x128xi4>
+  %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+  %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
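The flattened sizes in the test's CHECK lines (`memref<16384xi8>` and `memref<4096xi32>`) follow from packing the `32x8x128` `i4` elements into wider containers. The sketch below reproduces that arithmetic for static shapes. `flattenedNumElements` is a hypothetical helper for this illustration, not a function from the patch, and the round-up for element counts that do not divide evenly is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the patch): number of `dstBits`-wide
// containers needed to hold a statically shaped memref whose elements are
// `srcBits` wide, e.g. i4 elements packed into i8 or i32 containers.
std::int64_t flattenedNumElements(const std::vector<std::int64_t> &shape,
                                  int srcBits, int dstBits) {
  assert(srcBits > 0 && dstBits % srcBits == 0 &&
         "container must hold a whole number of elements");
  std::int64_t numElements = 1;
  for (std::int64_t dim : shape)
    numElements *= dim;
  const std::int64_t perContainer = dstBits / srcBits;
  // Assumption: round up so a partially filled container is still counted.
  return (numElements + perContainer - 1) / perContainer;
}
```

For the test input, `32 * 8 * 128 = 32768` `i4` elements pack two per `i8` (16384 containers) or eight per `i32` (4096 containers). The collapsed shape `256x128` has the same element count, so it yields the same flattened sizes.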
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Nice! Thank you!

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes, and strides of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introducing a reshape_cast op.
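As a rough illustration of what resolving the metadata of a collapse directly can look like, the sketch below derives collapsed sizes and strides from the source sizes, strides, and reassociation groups. It assumes static values and that each group is contiguous, so the group's stride is that of its innermost dimension. `StridedLayout` and `collapseLayout` are hypothetical names for this example and do not correspond to MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical container for static sizes and strides (not an MLIR type).
struct StridedLayout {
  std::vector<std::int64_t> sizes;
  std::vector<std::int64_t> strides;
};

// Hypothetical helper: compute the layout of a collapsed memref.
// Assumes every reassociation group is contiguous in memory.
StridedLayout
collapseLayout(const StridedLayout &src,
               const std::vector<std::vector<int>> &reassociation) {
  assert(src.sizes.size() == src.strides.size() && "one stride per size");
  StridedLayout result;
  for (const std::vector<int> &group : reassociation) {
    assert(!group.empty() && "each group collapses at least one dimension");
    // The collapsed size is the product of the sizes in the group.
    std::int64_t size = 1;
    for (int dim : group)
      size *= src.sizes[dim];
    result.sizes.push_back(size);
    // Assumption: contiguous group, so the innermost stride is kept.
    result.strides.push_back(src.strides[group.back()]);
  }
  return result;
}
```

For a `32x8x128` source with strides `[1024, 128, 1]` and groups `[[0, 1], [2]]`, this gives sizes `[256, 128]` and strides `[128, 1]`; the offset is unchanged by the collapse.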
@dcaballe dcaballe merged commit 571831a into llvm:main Apr 26, 2024

4 participants